diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 198645d875c47..2aa0f2117364c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -943,7 +943,7 @@ class Analyzer( failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF") // To resolve duplicate expression IDs for Join and Intersect - case j @ Join(left, right, _, _) if !j.duplicateResolved => + case j @ Join(left, right, _, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) case i @ Intersect(left, right, _) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) @@ -2249,13 +2249,14 @@ class Analyzer( */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { - case j @ Join(left, right, UsingJoin(joinType, usingCols), _) + case j @ Join(left, right, UsingJoin(joinType, usingCols), _, hint) if left.resolved && right.resolved && j.duplicateResolved => - commonNaturalJoinProcessing(left, right, joinType, usingCols, None) - case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => + commonNaturalJoinProcessing(left, right, joinType, usingCols, None, hint) + case j @ Join(left, right, NaturalJoin(joinType), condition, hint) + if j.resolvedExceptNatural => // find common column names from both sides val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) - commonNaturalJoinProcessing(left, right, joinType, joinNames, condition) + commonNaturalJoinProcessing(left, right, joinType, joinNames, condition, hint) } } @@ -2360,7 +2361,8 @@ class Analyzer( right: LogicalPlan, joinType: JoinType, joinNames: Seq[String], - condition: Option[Expression]) = { + condition: Option[Expression], + hint: JoinHint) = { val leftKeys = joinNames.map { keyName => left.output.find(attr => resolver(attr.name, keyName)).getOrElse { throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the left " + @@ -2401,7 +2403,7 @@ class Analyzer( sys.error("Unsupported natural join type " + joinType) } // use Project to trim unnecessary fields - Project(projectList, Join(left, right, joinType, newCondition)) + Project(projectList, Join(left, right, joinType, newCondition, hint)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c28a97839fe49..18c40b370cb5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -172,7 +172,7 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis("Null-aware predicate sub-queries cannot be used in nested " + s"conditions: $condition") - case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => + case j @ Join(_, _, _, Some(condition), _) if condition.dataType != BooleanType => failAnalysis( s"join condition '${condition.sql}' " + s"of type ${condition.dataType.catalogString} is not a boolean.") @@ -609,7 +609,7 @@ trait CheckAnalysis extends PredicateHelper { failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) // Join can host correlated expressions. 
- case j @ Join(left, right, joinType, _) => + case j @ Join(left, right, joinType, _, _) => joinType match { // Inner join, like Filter, can be anywhere. case _: InnerLike => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala index 7a0aa08289efa..76733dd6dac3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala @@ -41,7 +41,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging { */ def isWatermarkInJoinKeys(plan: LogicalPlan): Boolean = { plan match { - case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) => + case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _, _) => (leftKeys ++ rightKeys).exists { case a: AttributeReference => a.metadata.contains(EventTimeWatermark.delayKey) case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index cff4cee09427f..41ba6d34b5499 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -229,7 +229,7 @@ object UnsupportedOperationChecker { throwError("dropDuplicates is not supported after aggregation on a " + "streaming DataFrame/Dataset") - case Join(left, right, joinType, condition) => + case Join(left, right, joinType, condition, _) => joinType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 151481c80ee96..846ee3b386527 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -325,7 +325,7 @@ package object dsl { otherPlan: LogicalPlan, joinType: JoinType = Inner, condition: Option[Expression] = None): LogicalPlan = - Join(logicalPlan, otherPlan, joinType, condition) + Join(logicalPlan, otherPlan, joinType, condition, JoinHint.NONE) def cogroup[Key: Encoder, Left: Encoder, Right: Encoder, Result: Encoder]( otherPlan: LogicalPlan, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 01634a9d852c6..743d3ce944fe2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper} import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, JoinType} -import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -31,6 +31,40 @@ import org.apache.spark.sql.internal.SQLConf * 
Cost-based join reorder.
 * We may have several join reorder algorithms in the future. This class is the entry of these
 * algorithms, and chooses which one to use.
+ *
+ * Note that join strategy hints, e.g. the broadcast hint, do not interfere with the reordering.
+ * Such hints will be applied to the equivalent counterparts (i.e., join between the same relations
+ * regardless of the join order) of the original nodes after reordering.
+ * For example, the plan before reordering looks like:
+ *
+ *                Join
+ *                /  \
+ *           Hint1    t4
+ *             /
+ *           Join
+ *           /  \
+ *       Join    t3
+ *       /  \
+ *   Hint2    t2
+ *     /
+ *   t1
+ *
+ * The original join order as illustrated above is "((t1 JOIN t2) JOIN t3) JOIN t4", and after
+ * reordering, the new join order is "((t1 JOIN t3) JOIN t2) JOIN t4", so the new plan looks like:
+ *
+ *                Join
+ *                /  \
+ *           Hint1    t4
+ *             /
+ *           Join
+ *           /  \
+ *       Join    t2
+ *       /  \
+ *     t1    t3
+ *
+ * "Hint1" is applied to "(t1 JOIN t3) JOIN t2" as it is equivalent to the original hinted node,
+ * "(t1 JOIN t2) JOIN t3"; "Hint2", however, has disappeared from the new plan since there is no
+ * equivalent node to "t1 JOIN t2".
 */
object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
@@ -40,24 +74,30 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
     if (!conf.cboEnabled || !conf.joinReorderEnabled) {
       plan
     } else {
+      // Use a map to track the hints on the join items.
+      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
       val result = plan transformDown {
         // Start reordering with a joinable item, which is an InnerLike join with conditions.
-        case j @ Join(_, _, _: InnerLike, Some(cond)) =>
-          reorder(j, j.output)
-        case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond)))
+        case j @ Join(_, _, _: InnerLike, Some(cond), _) =>
+          reorder(j, j.output, hintMap)
+        case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), _))
           if projectList.forall(_.isInstanceOf[Attribute]) =>
-          reorder(p, p.output)
+          reorder(p, p.output, hintMap)
       }
-
-      // After reordering is finished, convert OrderedJoin back to Join
-      result transformDown {
-        case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond)
+      // After reordering is finished, convert OrderedJoin back to Join.
+      result transform {
+        case OrderedJoin(left, right, jt, cond) =>
+          val joinHint = JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet))
+          Join(left, right, jt, cond, joinHint)
       }
     }
   }

-  private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = {
-    val (items, conditions) = extractInnerJoins(plan)
+  private def reorder(
+      plan: LogicalPlan,
+      output: Seq[Attribute],
+      hintMap: mutable.HashMap[AttributeSet, HintInfo]): LogicalPlan = {
+    val (items, conditions) = extractInnerJoins(plan, hintMap)
     val result =
       // Do reordering if the number of items is appropriate and join conditions exist.
       // We also need to check if costs of all items can be evaluated.
@@ -75,27 +115,31 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
    * Extracts items of consecutive inner joins and join conditions.
    * This method works for bushy trees and left/right deep trees.
*/ - private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { + private def extractInnerJoins( + plan: LogicalPlan, + hintMap: mutable.HashMap[AttributeSet, HintInfo]): (Seq[LogicalPlan], Set[Expression]) = { plan match { - case Join(left, right, _: InnerLike, Some(cond)) => - val (leftPlans, leftConditions) = extractInnerJoins(left) - val (rightPlans, rightConditions) = extractInnerJoins(right) + case Join(left, right, _: InnerLike, Some(cond), hint) => + hint.leftHint.foreach(hintMap.put(left.outputSet, _)) + hint.rightHint.foreach(hintMap.put(right.outputSet, _)) + val (leftPlans, leftConditions) = extractInnerJoins(left, hintMap) + val (rightPlans, rightConditions) = extractInnerJoins(right, hintMap) (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++ leftConditions ++ rightConditions) - case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) + case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) if projectList.forall(_.isInstanceOf[Attribute]) => - extractInnerJoins(j) + extractInnerJoins(j, hintMap) case _ => (Seq(plan), Set()) } } private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { - case j @ Join(left, right, jt: InnerLike, Some(cond)) => + case j @ Join(left, right, jt: InnerLike, Some(cond), _) => val replacedLeft = replaceWithOrderedJoin(left) val replacedRight = replaceWithOrderedJoin(right) OrderedJoin(replacedLeft, replacedRight, jt, Some(cond)) - case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) => + case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) => p.copy(child = replaceWithOrderedJoin(j)) case _ => plan @@ -295,7 +339,7 @@ object JoinReorderDP extends PredicateHelper with Logging { } else { (otherPlan, onePlan) } - val newJoin = Join(left, right, Inner, joinConds.reduceOption(And)) + val newJoin = Join(left, right, Inner, joinConds.reduceOption(And), JoinHint.NONE) val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds val remainingConds = conditions -- collectedJoinConds val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala new file mode 100644 index 0000000000000..bbe4eee4b4326 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * Replaces [[ResolvedHint]] operators from the plan. Move the [[HintInfo]] to associated [[Join]] + * operators, otherwise remove it if no [[Join]] operator is matched. + */ +object EliminateResolvedHint extends Rule[LogicalPlan] { + // This is also called in the beginning of the optimization phase, and as a result + // is using transformUp rather than resolveOperators. + def apply(plan: LogicalPlan): LogicalPlan = { + val pulledUp = plan transformUp { + case j: Join => + val leftHint = mergeHints(collectHints(j.left)) + val rightHint = mergeHints(collectHints(j.right)) + j.copy(hint = JoinHint(leftHint, rightHint)) + } + pulledUp.transform { + case h: ResolvedHint => h.child + } + } + + private def mergeHints(hints: Seq[HintInfo]): Option[HintInfo] = { + hints.reduceOption((h1, h2) => HintInfo( + broadcast = h1.broadcast || h2.broadcast)) + } + + private def collectHints(plan: LogicalPlan): Seq[HintInfo] = { + plan match { + case h: ResolvedHint => collectHints(h.child) :+ h.hints + case u: UnaryNode => collectHints(u.child) + // TODO revisit this logic: + // except and intersect are semi/anti-joins which won't return more data then + // their left argument, so the broadcast hint should be propagated here + case i: Intersect => collectHints(i.left) + case e: Except => collectHints(e.left) + case _ => Seq.empty + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 44d5543114902..06f908281dd3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -115,6 +115,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) // However, because we also use the analyzer to canonicalized queries (for view definition), // we do not eliminate subqueries or compute current time in the analyzer. Batch("Finish Analysis", Once, + EliminateResolvedHint, EliminateSubqueryAliases, EliminateView, ReplaceExpressions, @@ -192,6 +193,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) */ def nonExcludableRules: Seq[String] = EliminateDistinct.ruleName :: + EliminateResolvedHint.ruleName :: EliminateSubqueryAliases.ruleName :: EliminateView.ruleName :: ReplaceExpressions.ruleName :: @@ -356,7 +358,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { // not allowed to use the same attributes. We use a blacklist to prevent us from creating a // situation in which this happens; the rule will only remove an alias if its child // attribute is not on the black list. - case Join(left, right, joinType, condition) => + case Join(left, right, joinType, condition, hint) => val newLeft = removeRedundantAliases(left, blacklist ++ right.outputSet) val newRight = removeRedundantAliases(right, blacklist ++ newLeft.outputSet) val mapping = AttributeMap( @@ -365,7 +367,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { val newCondition = condition.map(_.transform { case a: Attribute => mapping.getOrElse(a, a) }) - Join(newLeft, newRight, joinType, newCondition) + Join(newLeft, newRight, joinType, newCondition, hint) case _ => // Remove redundant aliases in the subtree(s). 
@@ -460,7 +462,7 @@ object LimitPushDown extends Rule[LogicalPlan] { // on both sides if it is applied multiple times. Therefore: // - If one side is already limited, stack another limit on top if the new limit is smaller. // The redundant limit will be collapsed by the CombineLimits rule. - case LocalLimit(exp, join @ Join(left, right, joinType, _)) => + case LocalLimit(exp, join @ Join(left, right, joinType, _, _)) => val newJoin = joinType match { case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left)) @@ -578,7 +580,7 @@ object ColumnPruning extends Rule[LogicalPlan] { p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices)) // Eliminate unneeded attributes from right side of a Left Existence Join. - case j @ Join(_, right, LeftExistence(_), _) => + case j @ Join(_, right, LeftExistence(_), _, _) => j.copy(right = prunedChild(right, j.references)) // all the columns will be used to compare, so we can't prune them @@ -792,7 +794,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] filter } - case join @ Join(left, right, joinType, conditionOpt) => + case join @ Join(left, right, joinType, conditionOpt, _) => joinType match { // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an // inner join, it just drops the right side in the final output. @@ -919,7 +921,6 @@ object RemoveRedundantSorts extends Rule[LogicalPlan] { def canEliminateSort(plan: LogicalPlan): Boolean = plan match { case p: Project => p.projectList.forall(_.deterministic) case f: Filter => f.condition.deterministic - case _: ResolvedHint => true case _ => false } } @@ -1094,7 +1095,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // Note that some operators (e.g. project, aggregate, union) are being handled separately // (earlier in this rule). case _: AppendColumns => true - case _: ResolvedHint => true case _: Distinct => true case _: Generate => true case _: Pivot => true @@ -1179,7 +1179,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // push the where condition down into join filter - case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) => + case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition, hint)) => val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = split(splitConjunctivePredicates(filterCondition), left, right) joinType match { @@ -1193,7 +1193,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { commonFilterCondition.partition(canEvaluateWithinJoin) val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And) - val join = Join(newLeft, newRight, joinType, newJoinCond) + val join = Join(newLeft, newRight, joinType, newJoinCond, hint) if (others.nonEmpty) { Filter(others.reduceLeft(And), join) } else { @@ -1205,7 +1205,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newRight = rightFilterConditions. reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = joinCondition - val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond) + val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond, hint) (leftFilterConditions ++ commonFilterCondition). 
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) @@ -1215,7 +1215,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) val newRight = right val newJoinCond = joinCondition - val newJoin = Join(newLeft, newRight, joinType, newJoinCond) + val newJoin = Join(newLeft, newRight, joinType, newJoinCond, hint) (rightFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) @@ -1225,7 +1225,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } // push down the join filter into sub query scanning if applicable - case j @ Join(left, right, joinType, joinCondition) => + case j @ Join(left, right, joinType, joinCondition, hint) => val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) @@ -1238,7 +1238,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = commonJoinCondition.reduceLeftOption(And) - Join(newLeft, newRight, joinType, newJoinCond) + Join(newLeft, newRight, joinType, newJoinCond, hint) case RightOuter => // push down the left side only join filter for left side sub query val newLeft = leftJoinConditions. @@ -1246,7 +1246,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newRight = right val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And) - Join(newLeft, newRight, RightOuter, newJoinCond) + Join(newLeft, newRight, RightOuter, newJoinCond, hint) case LeftOuter | LeftAnti | ExistenceJoin(_) => // push down the right side only join filter for right sub query val newLeft = left @@ -1254,7 +1254,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) - Join(newLeft, newRight, joinType, newJoinCond) + Join(newLeft, newRight, joinType, newJoinCond, hint) case FullOuter => j case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") case UsingJoin(_, _) => sys.error("Untransformed Using join node") @@ -1310,7 +1310,7 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { if (SQLConf.get.crossJoinEnabled) { plan } else plan transform { - case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _) + case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _, _) if isCartesianProduct(j) => throw new AnalysisException( s"""Detected implicit cartesian product for ${j.joinType.sql} join between logical plans @@ -1449,7 +1449,7 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { case Intersect(left, right, false) => assert(left.output.size == right.output.size) val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } - Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And))) + Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And), JoinHint.NONE)) } } @@ -1470,7 +1470,7 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { case Except(left, right, false) => assert(left.output.size == right.output.size) val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } - Distinct(Join(left, right, LeftAnti, 
joinCond.reduceLeftOption(And))) + Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And), JoinHint.NONE)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index c3fdb924243df..b19e13870aa65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -56,7 +56,7 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper wit // Joins on empty LocalRelations generated from streaming sources are not eliminated // as stateful streaming joins need to perform other state management operations other than // just processing the input data. - case p @ Join(_, _, joinType, _) + case p @ Join(_, _, joinType, _, _) if !p.children.exists(_.isStreaming) => val isLeftEmpty = isEmptyLocalRelation(p.left) val isRightEmpty = isEmptyLocalRelation(p.right) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 72a60f692ac78..689915a985343 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -52,7 +52,7 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) - case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond))) + case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond))) case p: LogicalPlan => p transformExpressions { case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) case cw @ CaseWhen(branches, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 468a950fb1087..39709529c00d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -600,7 +600,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { // propagating the foldable expressions. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. 
- case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty => + case j @ Join(left, right, joinType, _, _) if foldableMap.nonEmpty => val newJoin = j.transformExpressions(replaceFoldable) val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match { case _: InnerLike | LeftExistence(_) => Nil @@ -648,7 +648,6 @@ object FoldablePropagation extends Rule[LogicalPlan] { case _: Distinct => true case _: AppendColumns => true case _: AppendColumnsWithObject => true - case _: ResolvedHint => true case _: RepartitionByExpression => true case _: Repartition => true case _: Sort => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 0b6471289a471..82aefca8a1af6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -43,10 +43,13 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { * * @param input a list of LogicalPlans to inner join and the type of inner join. * @param conditions a list of condition for join. + * @param hintMap a map of relation output attribute sets to their corresponding hints. */ @tailrec - final def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) - : LogicalPlan = { + final def createOrderedJoin( + input: Seq[(LogicalPlan, InnerLike)], + conditions: Seq[Expression], + hintMap: Map[AttributeSet, HintInfo]): LogicalPlan = { assert(input.size >= 2) if (input.size == 2) { val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin) @@ -55,7 +58,8 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { case (Inner, Inner) => Inner case (_, _) => Cross } - val join = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And)) + val join = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And), + JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet))) if (others.nonEmpty) { Filter(others.reduceLeft(And), join) } else { @@ -78,26 +82,27 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { val joinedRefs = left.outputSet ++ right.outputSet val (joinConditions, others) = conditions.partition( e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e)) - val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And)) + val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And), + JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet))) // should not have reference to same logical plan - createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others) + createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others, hintMap) } } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p @ ExtractFiltersAndInnerJoins(input, conditions) + case p @ ExtractFiltersAndInnerJoins(input, conditions, hintMap) if input.size > 2 && conditions.nonEmpty => val reordered = if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) { val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions) if (starJoinPlan.nonEmpty) { val rest = input.filterNot(starJoinPlan.contains(_)) - createOrderedJoin(starJoinPlan ++ rest, conditions) + createOrderedJoin(starJoinPlan ++ rest, conditions, hintMap) } else { - createOrderedJoin(input, conditions) + createOrderedJoin(input, 
conditions, hintMap) } } else { - createOrderedJoin(input, conditions) + createOrderedJoin(input, conditions, hintMap) } if (p.sameOutput(reordered)) { @@ -156,7 +161,7 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) => + case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _, _)) => val newJoinType = buildNewJoinType(f, j) if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) } @@ -176,7 +181,7 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH } override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, j) => + case j @ Join(_, _, joinType, Some(cond), _) if hasUnevaluablePythonUDF(cond, j) => if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) { // The current strategy only support InnerLike and LeftSemi join because for other type, // it breaks SQL semantic if we run the join condition as a filter after join. If we pass diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 34840c6c977a6..e78ed1c3c5d94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -51,7 +51,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { condition: Option[Expression]): Join = { // Deduplicate conflicting attributes if any. val dedupSubplan = dedupSubqueryOnSelfJoin(outerPlan, subplan, None, condition) - Join(outerPlan, dedupSubplan, joinType, condition) + Join(outerPlan, dedupSubplan, joinType, condition, JoinHint.NONE) } private def dedupSubqueryOnSelfJoin( @@ -116,7 +116,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values)) val inConditions = values.zip(newSub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) - Join(outerPlan, newSub, LeftSemi, joinCond) + Join(outerPlan, newSub, LeftSemi, joinCond, JoinHint.NONE) case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. 
A NULL in one of the conditions is regarded as a positive @@ -142,7 +142,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // will have the final conditions in the LEFT ANTI as // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1 val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And) - Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond)) + Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond), JoinHint.NONE) case (p, predicate) => val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) Project(p.output, Filter(newCond.get, inputPlan)) @@ -172,7 +172,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values)) val inConditions = values.zip(newSub.output).map(EqualTo.tupled) val newConditions = (inConditions ++ conditions).reduceLeftOption(And) - newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions) + newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions, JoinHint.NONE) exists } } @@ -450,7 +450,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { // CASE 1: Subquery guaranteed not to have the COUNT bug Project( currentChild.output :+ origOutput, - Join(currentChild, query, LeftOuter, conditions.reduceOption(And))) + Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) } else { // Subquery might have the COUNT bug. Add appropriate corrections. val (topPart, havingNode, aggNode) = splitSubquery(query) @@ -477,7 +477,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { aggValRef), origOutput.name)(exprId = origOutput.exprId), Join(currentChild, Project(query.output :+ alwaysTrueExpr, query), - LeftOuter, conditions.reduceOption(And))) + LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) } else { // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. 
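The hunks above illustrate the patch's two recurring patterns: a rule that rebuilds an existing join threads the original hint through unchanged (e.g. PushPredicateThroughJoin), while a join synthesized by the optimizer itself, such as the subquery rewrites in this file, was never hinted by the user and is therefore created with JoinHint.NONE. A minimal sketch, not part of the patch, assuming the Join and JoinHint definitions introduced later in basicLogicalOperators.scala and hints.scala:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.plans.LeftSemi
    import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LocalRelation}

    // Stand-in relations for illustration only.
    val a = 'a.int
    val b = 'b.int
    val outer = LocalRelation(a)
    val sub = LocalRelation(b)

    // A semi join synthesized during subquery rewriting: no user ever hinted this node,
    // so it carries JoinHint.NONE rather than a hint propagated from its children.
    val semiJoin = Join(outer, sub, LeftSemi, Some(a === b), JoinHint.NONE)
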
@@ -507,7 +507,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { currentChild.output :+ caseExpr, Join(currentChild, Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), - LeftOuter, conditions.reduceOption(And))) + LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8959f78b656d2..a27c6d3c3671c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -515,7 +515,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) => val right = plan(relation.relationPrimary) - val join = right.optionalMap(left)(Join(_, _, Inner, None)) + val join = right.optionalMap(left)(Join(_, _, Inner, None, JoinHint.NONE)) withJoinRelations(join, relation) } if (ctx.pivotClause() != null) { @@ -727,7 +727,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case None => (baseJoinType, None) } - Join(left, plan(join.right), joinType, condition) + Join(left, plan(join.right), joinType, condition, JoinHint.NONE) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 84be677e438a6..dfc3b2d22129d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.planning +import scala.collection.mutable + import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ @@ -98,12 +100,13 @@ object PhysicalOperation extends PredicateHelper { * value). */ object ExtractEquiJoinKeys extends Logging with PredicateHelper { - /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */ + /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild, joinHint) */ type ReturnType = - (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) + (JoinType, Seq[Expression], Seq[Expression], + Option[Expression], LogicalPlan, LogicalPlan, JoinHint) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case join @ Join(left, right, joinType, condition) => + case join @ Join(left, right, joinType, condition, hint) => logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. @@ -133,7 +136,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { if (joinKeys.nonEmpty) { val (leftKeys, rightKeys) = joinKeys.unzip logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys") - Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) + Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right, hint)) } else { None } @@ -164,25 +167,35 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { * was involved in an explicit cross join. 
Also returns the entire list of join conditions for
   * the left-deep tree.
   */
-  def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner)
+  def flattenJoin(
+      plan: LogicalPlan,
+      hintMap: mutable.HashMap[AttributeSet, HintInfo],
+      parentJoinType: InnerLike = Inner)
     : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match {
-    case Join(left, right, joinType: InnerLike, cond) =>
-      val (plans, conditions) = flattenJoin(left, joinType)
+    case Join(left, right, joinType: InnerLike, cond, hint) =>
+      val (plans, conditions) = flattenJoin(left, hintMap, joinType)
+      hint.leftHint.foreach(hintMap.put(left.outputSet, _))
+      hint.rightHint.foreach(hintMap.put(right.outputSet, _))
       (plans ++ Seq((right, joinType)), conditions ++
         cond.toSeq.flatMap(splitConjunctivePredicates))
-    case Filter(filterCondition, j @ Join(left, right, _: InnerLike, joinCondition)) =>
-      val (plans, conditions) = flattenJoin(j)
+    case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, _)) =>
+      val (plans, conditions) = flattenJoin(j, hintMap)
       (plans, conditions ++ splitConjunctivePredicates(filterCondition))

     case _ => (Seq((plan, parentJoinType)), Seq.empty)
   }

-  def unapply(plan: LogicalPlan): Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]
+  def unapply(plan: LogicalPlan)
+    : Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression], Map[AttributeSet, HintInfo])]
       = plan match {
-    case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _)) =>
-      Some(flattenJoin(f))
-    case j @ Join(_, _, joinType, _) =>
-      Some(flattenJoin(j))
+    case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, _)) =>
+      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
+      val flattened = flattenJoin(f, hintMap)
+      Some((flattened._1, flattened._2, hintMap.toMap))
+    case j @ Join(_, _, joinType, _, _) =>
+      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
+      val flattened = flattenJoin(j, hintMap)
+      Some((flattened._1, flattened._2, hintMap.toMap))
     case _ => None
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala
index 2c248d74869ce..18baced8f3d61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala
@@ -37,7 +37,6 @@ trait LogicalPlanVisitor[T] {
     case p: Project => visitProject(p)
     case p: Repartition => visitRepartition(p)
     case p: RepartitionByExpression => visitRepartitionByExpr(p)
-    case p: ResolvedHint => visitHint(p)
     case p: Sample => visitSample(p)
     case p: ScriptTransformation => visitScriptTransform(p)
     case p: Union => visitUnion(p)
@@ -61,8 +60,6 @@ trait LogicalPlanVisitor[T] {

   def visitGlobalLimit(p: GlobalLimit): T

-  def visitHint(p: ResolvedHint): T
-
   def visitIntersect(p: Intersect): T

   def visitJoin(p: Join): T
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index b3a48860aa63b..5a388117a6c0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -52,13 +52,11 @@ import org.apache.spark.util.Utils
  * defaults to the product of children's `sizeInBytes`.
* @param rowCount Estimated number of rows. * @param attributeStats Statistics for Attributes. - * @param hints Query hints. */ case class Statistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, - attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil), - hints: HintInfo = HintInfo()) { + attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil)) { override def toString: String = "Statistics(" + simpleString + ")" @@ -70,8 +68,7 @@ case class Statistics( s"rowCount=${BigDecimal(rowCount.get, new MathContext(3, RoundingMode.HALF_UP)).toString()}" } else { "" - }, - s"hints=$hints" + } ).filter(_.nonEmpty).mkString(", ") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d8b3a4af4f7bf..639d68f4ecd76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -288,7 +288,8 @@ case class Join( left: LogicalPlan, right: LogicalPlan, joinType: JoinType, - condition: Option[Expression]) + condition: Option[Expression], + hint: JoinHint) extends BinaryNode with PredicateHelper { override def output: Seq[Attribute] = { @@ -350,6 +351,17 @@ case class Join( case UsingJoin(_, _) => false case _ => resolvedExceptNatural } + + // Ignore hint for canonicalization + protected override def doCanonicalize(): LogicalPlan = + super.doCanonicalize().asInstanceOf[Join].copy(hint = JoinHint.NONE) + + // Do not include an empty join hint in string description + protected override def stringArgs: Iterator[Any] = super.stringArgs.filter { e => + (!e.isInstanceOf[JoinHint] + || e.asInstanceOf[JoinHint].leftHint.isDefined + || e.asInstanceOf[JoinHint].rightHint.isDefined) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index cbb626590d1d7..b2ba725e9d44f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -35,6 +35,7 @@ case class UnresolvedHint(name: String, parameters: Seq[Any], child: LogicalPlan /** * A resolved hint node. The analyzer should convert all [[UnresolvedHint]] into [[ResolvedHint]]. + * This node will be eliminated before optimization starts. */ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) extends UnaryNode { @@ -44,11 +45,31 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) override def doCanonicalize(): LogicalPlan = child.canonicalized } +/** + * Hint that is associated with a [[Join]] node, with [[HintInfo]] on its left child and on its + * right child respectively. + */ +case class JoinHint(leftHint: Option[HintInfo], rightHint: Option[HintInfo]) { -case class HintInfo(broadcast: Boolean = false) { + override def toString: String = { + Seq( + leftHint.map("leftHint=" + _), + rightHint.map("rightHint=" + _)) + .filter(_.isDefined).map(_.get).mkString(", ") + } +} - /** Must be called when computing stats for a join operator to reset hints. */ - def resetForJoin(): HintInfo = copy(broadcast = false) +object JoinHint { + val NONE = JoinHint(None, None) +} + +/** + * The hint attributes to be applied on a specific node. 
+ * + * @param broadcast If set to true, it indicates that the broadcast hash join is the preferred join + * strategy and the node with this hint is preferred to be the build side. + */ +case class HintInfo(broadcast: Boolean = false) { override def toString: String = { val hints = scala.collection.mutable.ArrayBuffer.empty[String] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index 111c594a53e52..eb56ab43ea9d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -56,8 +56,7 @@ object AggregateEstimation { Some(Statistics( sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats), rowCount = Some(outputRows), - attributeStats = outputAttrStats, - hints = childStats.hints)) + attributeStats = outputAttrStats)) } else { None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index b6c16079d1984..b8c652dc8f12e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -47,8 +47,6 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitGlobalLimit(p: GlobalLimit): Statistics = fallback(p) - override def visitHint(p: ResolvedHint): Statistics = fallback(p) - override def visitIntersect(p: Intersect): Statistics = fallback(p) override def visitJoin(p: Join): Statistics = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 2543e38a92c0a..19a0d1279cc32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -56,7 +56,7 @@ case class JoinEstimation(join: Join) extends Logging { case _ if !rowCountsExist(join.left, join.right) => None - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _, _) => // 1. 
Compute join selectivity val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys) val (numInnerJoinedRows, keyStatsAfterJoin) = computeCardinalityAndStats(joinKeyPairs) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index ee43f9126386b..da36db7ae1f5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -44,7 +44,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { } // Don't propagate rowCount and attributeStats, since they are not estimated here. - Statistics(sizeInBytes = sizeInBytes, hints = p.child.stats.hints) + Statistics(sizeInBytes = sizeInBytes) } /** @@ -60,8 +60,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { if (p.groupingExpressions.isEmpty) { Statistics( sizeInBytes = EstimationUtils.getOutputSize(p.output, outputRowCount = 1), - rowCount = Some(1), - hints = p.child.stats.hints) + rowCount = Some(1)) } else { visitUnaryNode(p) } @@ -87,19 +86,15 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { // Don't propagate column stats, because we don't know the distribution after limit Statistics( sizeInBytes = EstimationUtils.getOutputSize(p.output, rowCount, childStats.attributeStats), - rowCount = Some(rowCount), - hints = childStats.hints) + rowCount = Some(rowCount)) } - override def visitHint(p: ResolvedHint): Statistics = p.child.stats.copy(hints = p.hints) - override def visitIntersect(p: Intersect): Statistics = { val leftSize = p.left.stats.sizeInBytes val rightSize = p.right.stats.sizeInBytes val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize Statistics( - sizeInBytes = sizeInBytes, - hints = p.left.stats.hints.resetForJoin()) + sizeInBytes = sizeInBytes) } override def visitJoin(p: Join): Statistics = { @@ -108,10 +103,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { // LeftSemi and LeftAnti won't ever be bigger than left p.left.stats case _ => - // Make sure we don't propagate isBroadcastable in other joins, because - // they could explode the size. - val stats = default(p) - stats.copy(hints = stats.hints.resetForJoin()) + default(p) } } @@ -121,7 +113,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { if (limit == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero // (product of children). - Statistics(sizeInBytes = 1, rowCount = Some(0), hints = childStats.hints) + Statistics(sizeInBytes = 1, rowCount = Some(0)) } else { // The output row count of LocalLimit should be the sum of row counts from each partition. 
// However, since the number of partitions is not available here, we just use statistics of @@ -147,7 +139,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { } val sampleRows = p.child.stats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio)) // Don't propagate column stats, because we don't know the distribution after a sample operation - Statistics(sizeInBytes, sampleRows, hints = p.child.stats.hints) + Statistics(sizeInBytes, sampleRows) } override def visitScriptTransform(p: ScriptTransformation): Statistics = default(p) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 117e96175e92a..129ce3b1105ee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -443,7 +443,7 @@ class AnalysisErrorSuite extends AnalysisTest { } test("error test for self-join") { - val join = Join(testRelation, testRelation, Cross, None) + val join = Join(testRelation, testRelation, Cross, None, JoinHint.NONE) val error = intercept[AnalysisException] { SimpleAnalyzer.checkAnalysis(join) } @@ -565,7 +565,8 @@ class AnalysisErrorSuite extends AnalysisTest { LocalRelation(b), Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)), LeftOuter, - Option(EqualTo(b, c)))), + Option(EqualTo(b, c)), + JoinHint.NONE)), LocalRelation(a)) assertAnalysisError(plan1, "Accessing outer query column is not allowed in" :: Nil) @@ -575,7 +576,8 @@ class AnalysisErrorSuite extends AnalysisTest { Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)), LocalRelation(b), RightOuter, - Option(EqualTo(b, c)))), + Option(EqualTo(b, c)), + JoinHint.NONE)), LocalRelation(a)) assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index da3ae72c3682a..982948483fa1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -397,7 +397,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { Join( Project(Seq($"x.key"), SubqueryAlias("x", input)), Project(Seq($"y.key"), SubqueryAlias("y", input)), - Cross, None)) + Cross, None, JoinHint.NONE)) assertAnalysisSuccess(query) } @@ -578,7 +578,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { Seq(UnresolvedAttribute("a")), pythonUdf, output, project) val left = SubqueryAlias("temp0", flatMapGroupsInPandas) val right = SubqueryAlias("temp1", flatMapGroupsInPandas) - val join = Join(left, right, Inner, None) + val join = Join(left, right, Inner, None, JoinHint.NONE) assertAnalysisSuccess( Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index bd66ee5355f45..563e8adf87edc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -60,7 +60,7 @@ class ResolveHintsSuite extends AnalysisTest { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), Join(ResolvedHint(testRelation, HintInfo(broadcast = true)), - ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None), + ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None, JoinHint.NONE), caseSensitive = false) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 57195d5fda7c5..0cd6e092e2036 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -353,15 +353,15 @@ class ColumnPruningSuite extends PlanTest { Project(Seq($"x.key", $"y.key"), Join( SubqueryAlias("x", input), - ResolvedHint(SubqueryAlias("y", input)), Inner, None)).analyze + SubqueryAlias("y", input), Inner, None, JoinHint.NONE)).analyze val optimized = Optimize.execute(query) val expected = Join( Project(Seq($"x.key"), SubqueryAlias("x", input)), - ResolvedHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), - Inner, None).analyze + Project(Seq($"y.key"), SubqueryAlias("y", input)), + Inner, None, JoinHint.NONE).analyze comparePlans(optimized, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 82a10254d846d..cf4e9fcea2c6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -822,19 +821,6 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("broadcast hint") { - val originalQuery = ResolvedHint(testRelation) - .where('a === 2L && 'b + Rand(10).as("rnd") === 3) - - val optimized = Optimize.execute(originalQuery.analyze) - - val correctAnswer = ResolvedHint(testRelation.where('a === 2L)) - .where('b + Rand(10).as("rnd") === 3) - .analyze - - comparePlans(optimized, correctAnswer) - } - test("union") { val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 6fe5e619d03ad..9093d7fecb0f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -65,7 +65,8 @@ class JoinOptimizationSuite extends PlanTest { def testExtractCheckCross (plan: LogicalPlan, expected: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]) { - assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected) + assert( + 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 6fe5e619d03ad..9093d7fecb0f7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -65,7 +65,8 @@ class JoinOptimizationSuite extends PlanTest {
     def testExtractCheckCross
         (plan: LogicalPlan, expected: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]) {
-      assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected)
+      assert(
+        ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, e._2, Map.empty)))
     }
 
     testExtract(x, None)
@@ -124,29 +125,4 @@ class JoinOptimizationSuite extends PlanTest {
       comparePlans(optimized, queryAnswerPair._2.analyze)
     }
   }
-
-  test("broadcasthint sets relation statistics to smallest value") {
-    val input = LocalRelation('key.int, 'value.string)
-
-    val query =
-      Project(Seq($"x.key", $"y.key"),
-        Join(
-          SubqueryAlias("x", input),
-          ResolvedHint(SubqueryAlias("y", input)), Cross, None)).analyze
-
-    val optimized = Optimize.execute(query)
-
-    val expected =
-      Join(
-        Project(Seq($"x.key"), SubqueryAlias("x", input)),
-        ResolvedHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
-        Cross, None).analyze
-
-    comparePlans(optimized, expected)
-
-    val broadcastChildren = optimized.collect {
-      case Join(_, r, _, _) if r.stats.sizeInBytes == 1 => r
-    }
-    assert(broadcastChildren.size == 1)
-  }
 }
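The reorder suite below now runs an `EliminateResolvedHint` batch before the optimizations under test. The rule's body is not part of this diff, but the behavior the tests rely on is that a `ResolvedHint` wrapping a join side is folded into the enclosing `Join`'s hint field and the hint node removed; a rough before/after sketch, assuming the `JoinHint` shape above:

    // before: Join(ResolvedHint(t1, HintInfo(broadcast = true)), t2, Inner, None, JoinHint.NONE)
    // after:  Join(t1, t2, Inner, None, JoinHint(Some(HintInfo(broadcast = true)), None))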
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
index c94a8b9e318f6..0dee846205868 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
@@ -31,6 +31,8 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
+      Batch("Resolve Hints", Once,
+        EliminateResolvedHint) ::
       Batch("Operator Optimizations", FixedPoint(100),
         CombineFilters,
         PushDownPredicate,
@@ -42,6 +44,12 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
         CostBasedJoinReorder) :: Nil
   }
 
+  object ResolveHints extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Resolve Hints", Once,
+        EliminateResolvedHint) :: Nil
+  }
+
   var originalConfCBOEnabled = false
   var originalConfJoinReorderEnabled = false
 
@@ -284,12 +292,85 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
     assertEqualPlans(originalPlan, bestPlan)
   }
 
+  test("hints preservation") {
+    // Apply hints if we find an equivalent node in the new plan, otherwise discard them.
+    val originalPlan =
+      t1.join(t2.hint("broadcast")).hint("broadcast").join(t4.join(t3).hint("broadcast"))
+        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
+          (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+          (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+
+    val bestPlan =
+      t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+        .hint("broadcast")
+        .join(
+          t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+            .hint("broadcast"),
+          Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+
+    assertEqualPlans(originalPlan, bestPlan)
+
+    val originalPlan2 =
+      t1.join(t2).hint("broadcast").join(t3).hint("broadcast").join(t4.hint("broadcast"))
+        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
+          (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+          (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+
+    val bestPlan2 =
+      t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+        .hint("broadcast")
+        .join(
+          t4.hint("broadcast")
+            .join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
+          Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+        .select(outputsOf(t1, t2, t3, t4): _*)
+
+    assertEqualPlans(originalPlan2, bestPlan2)
+
+    val originalPlan3 =
+      t1.join(t4).hint("broadcast")
+        .join(t2.hint("broadcast")).hint("broadcast")
+        .join(t3.hint("broadcast"))
+        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
+          (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+          (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+
+    val bestPlan3 =
+      t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+        .join(
+          t4.join(t3.hint("broadcast"),
+            Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
+          Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+        .select(outputsOf(t1, t4, t2, t3): _*)
+
+    assertEqualPlans(originalPlan3, bestPlan3)
+
+    val originalPlan4 =
+      t2.hint("broadcast")
+        .join(t4).hint("broadcast")
+        .join(t3.hint("broadcast")).hint("broadcast")
+        .join(t1)
+        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
+          (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+          (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+
+    val bestPlan4 =
+      t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+        .join(
+          t4.join(t3.hint("broadcast"),
+            Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
+          Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+        .select(outputsOf(t2, t4, t3, t1): _*)
+
+    assertEqualPlans(originalPlan4, bestPlan4)
+  }
+
   private def assertEqualPlans(
       originalPlan: LogicalPlan,
       groundTruthBestPlan: LogicalPlan): Unit = {
     val analyzed = originalPlan.analyze
     val optimized = Optimize.execute(analyzed)
-    val expected = groundTruthBestPlan.analyze
+    val expected = ResolveHints.execute(groundTruthBestPlan.analyze)
     assert(analyzed.sameOutput(expected))  // if this fails, the expected plan itself is incorrect
     assert(analyzed.sameOutput(optimized))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index c8e15c7da763e..6d1af12e68b23 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -48,7 +48,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze
+        Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd), JoinHint.NONE)).analyze
 
     comparePlans(optimized, correctAnswer)
   }
@@ -160,7 +160,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Join(table1, table2, LeftAnti, Option('a <=> 'c && 'b <=> 'd))).analyze
+        Join(table1, table2, LeftAnti, Option('a <=> 'c && 'b <=> 'd), JoinHint.NONE)).analyze
 
     comparePlans(optimized, correctAnswer)
   }
@@ -175,7 +175,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(left.output, right.output,
-        Join(left, right, LeftAnti, Option($"left.a" <=> $"right.a"))).analyze
+        Join(left, right, LeftAnti, Option($"left.a" <=> $"right.a"), JoinHint.NONE)).analyze
 
     comparePlans(optimized, correctAnswer)
   }
@@ -248,7 +248,7 @@ class ReplaceOperatorSuite extends PlanTest {
     val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) => a1 <=> a2 }.reduce( _ && _)
     val correctAnswer = Aggregate(basePlan.output, otherPlan.output,
-      Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze
+      Join(basePlan, otherPlan, LeftAnti, Option(condition), JoinHint.NONE)).analyze
     comparePlans(result, correctAnswer)
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 3081ff935f043..5394732f41f2d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -99,11 +99,11 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite =>
           .reduce(And), child)
       case sample: Sample => sample.copy(seed = 0L)
-      case Join(left, right, joinType, condition) if condition.isDefined =>
+      case Join(left, right, joinType, condition, hint) if condition.isDefined =>
         val newCondition =
           splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode())
             .reduce(And)
-        Join(left, right, joinType, Some(newCondition))
+        Join(left, right, joinType, Some(newCondition), hint)
     }
   }
 
@@ -165,8 +165,10 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite =>
   private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
     (plan1, plan2) match {
       case (j1: Join, j2: Join) =>
-        (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) ||
-          (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left))
+        (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)
+          && j1.hint.leftHint == j2.hint.leftHint && j1.hint.rightHint == j2.hint.rightHint) ||
+          (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)
+            && j1.hint.leftHint == j2.hint.rightHint && j1.hint.rightHint == j2.hint.leftHint)
       case (p1: Project, p2: Project) =>
         p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child)
       case _ =>
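Note how `sameJoinPlan` keeps accepting commuted joins but now also demands matching hints: when the two sides are compared swapped, `leftHint` must line up with `rightHint` and vice versa, so a broadcast hint follows the side it annotates rather than a fixed position in the plan.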
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
index 7c8ed78a49116..fbaaf807af5d6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.plans
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, ResolvedHint, Union}
+import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.util._
 
 /**
@@ -30,6 +32,10 @@ class SameResultSuite extends SparkFunSuite {
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
   val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)
 
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("EliminateResolvedHint", Once, EliminateResolvedHint) :: Nil
+  }
+
   def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
     val aAnalyzed = a.analyze
     val bAnalyzed = b.analyze
@@ -72,4 +78,12 @@ class SameResultSuite extends SparkFunSuite {
     val df2 = testRelation.join(testRelation)
     assertSameResult(df1, df2)
   }
+
+  test("join hint") {
+    val df1 = testRelation.join(testRelation.hint("broadcast"))
+    val df2 = testRelation.join(testRelation)
+    val df1Optimized = Optimize.execute(df1.analyze)
+    val df2Optimized = Optimize.execute(df2.analyze)
+    assertSameResult(df1Optimized, df2Optimized)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
index 953094cb0dd52..16a5c2d3001a7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -38,24 +38,6 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
     // row count * (overhead + column size)
     size = Some(10 * (8 + 4)))
 
-  test("BroadcastHint estimation") {
-    val filter = Filter(Literal(true), plan)
-    val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4),
-      rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat)))
-    val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4))
-    checkStats(
-      filter,
-      expectedStatsCboOn = filterStatsCboOn,
-      expectedStatsCboOff = filterStatsCboOff)
-
-    val broadcastHint = ResolvedHint(filter, HintInfo(broadcast = true))
-    checkStats(
-      broadcastHint,
-      expectedStatsCboOn = filterStatsCboOn.copy(hints = HintInfo(broadcast = true)),
-      expectedStatsCboOff = filterStatsCboOff.copy(hints = HintInfo(broadcast = true))
-    )
-  }
-
   test("range") {
     val range = Range(1, 5, 1, None)
     val rangeStats = Statistics(sizeInBytes = 4 * 8)
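With the hint removed from `Statistics`, hint propagation no longer has anything to do with stats estimation, which is why the "BroadcastHint estimation" test is deleted outright rather than rewritten. What remains of a `Statistics` value is just size, row count and column stats, e.g.:

    // Statistics now reports only sizeInBytes/rowCount/attributeStats; its simpleString
    // (see the StatisticsCollectionSuite change below) drops the "hints=..." suffix.
    val stats = Statistics(sizeInBytes = BigInt(120), rowCount = Some(BigInt(10)))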
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index b0a47e7835129..1cf888519077a 100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -528,7 +528,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       rowCount = 30,
       attributeStats = AttributeMap(Seq(attrIntLargerRange -> colStatIntLargerRange)))
     val nonLeafChild = Join(largerTable, smallerTable, LeftOuter,
-      Some(EqualTo(attrIntLargerRange, attrInt)))
+      Some(EqualTo(attrIntLargerRange, attrInt)), JoinHint.NONE)
 
     Seq(IsNull(attrIntLargerRange), IsNotNull(attrIntLargerRange)).foreach { predicate =>
       validateEstimatedStats(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
index 12c0a7be21292..6c5a2b247fc23 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
@@ -79,8 +79,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     val c1 = generateJoinChild(col1, leftHistogram, expectedMin, expectedMax)
     val c2 = generateJoinChild(col2, rightHistogram, expectedMin, expectedMax)
-    val c1JoinC2 = Join(c1, c2, Inner, Some(EqualTo(col1, col2)))
-    val c2JoinC1 = Join(c2, c1, Inner, Some(EqualTo(col2, col1)))
+    val c1JoinC2 = Join(c1, c2, Inner, Some(EqualTo(col1, col2)), JoinHint.NONE)
+    val c2JoinC1 = Join(c2, c1, Inner, Some(EqualTo(col2, col1)), JoinHint.NONE)
     val expectedStatsAfterJoin = Statistics(
       sizeInBytes = expectedRows * (8 + 2 * 4),
       rowCount = Some(expectedRows),
@@ -284,7 +284,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
   test("cross join") {
     // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5)
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
-    val join = Join(table1, table2, Cross, None)
+    val join = Join(table1, table2, Cross, None, JoinHint.NONE)
     val expectedStats = Statistics(
       sizeInBytes = 5 * 3 * (8 + 4 * 4),
       rowCount = Some(5 * 3),
@@ -299,7 +299,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     // key-5-9 and key-2-4 are disjoint
     val join = Join(table1, table2, Inner,
-      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))), JoinHint.NONE)
     val expectedStats = Statistics(
       sizeInBytes = 1,
       rowCount = Some(0),
@@ -312,7 +312,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     // key-5-9 and key-2-4 are disjoint
     val join = Join(table1, table2, LeftOuter,
-      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))), JoinHint.NONE)
     val expectedStats = Statistics(
       sizeInBytes = 5 * (8 + 4 * 4),
       rowCount = Some(5),
@@ -328,7 +328,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     // key-5-9 and key-2-4 are disjoint
     val join = Join(table1, table2, RightOuter,
-      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))), JoinHint.NONE)
     val expectedStats = Statistics(
       sizeInBytes = 3 * (8 + 4 * 4),
       rowCount = Some(3),
@@ -344,7 +344,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     // key-5-9 and key-2-4 are disjoint
     val join = Join(table1, table2, FullOuter,
-      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))))
+      Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))), JoinHint.NONE)
     val expectedStats = Statistics(
       sizeInBytes = (5 + 3) * (8 + 4 * 4),
       rowCount = Some(5 + 3),
@@ -361,7 +361,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5)
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     val join = Join(table1, table2, Inner,
-      Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2"))))
+      Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2"))), JoinHint.NONE)
     // Update column stats for equi-join keys (key-1-5 and key-1-2).
     val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2),
       nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
@@ -383,7 +383,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
     val join = Join(table2, table3, Inner, Some(
       And(EqualTo(nameToAttr("key-1-2"), nameToAttr("key-1-2")),
-      EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))))
+      EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))), JoinHint.NONE)
 
     // Update column stats for join keys.
     val joinedColStat1 = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2),
@@ -404,7 +404,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
     val join = Join(table3, table2, LeftOuter,
-      Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4"))))
+      Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4"))), JoinHint.NONE)
     val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3),
       nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
@@ -422,7 +422,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
     val join = Join(table2, table3, RightOuter,
-      Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))
+      Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))), JoinHint.NONE)
     val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3),
       nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
@@ -440,7 +440,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
     // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
     val join = Join(table2, table3, FullOuter,
-      Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))
+      Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))), JoinHint.NONE)
 
     val expectedStats = Statistics(
       sizeInBytes = 3 * (8 + 4 * 4),
@@ -456,7 +456,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
     Seq(LeftSemi, LeftAnti).foreach { jt =>
       val join = Join(table2, table3, jt,
-        Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))
+        Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))), JoinHint.NONE)
       // For now we just propagate the statistics from left side for left semi/anti join.
       val expectedStats = Statistics(
         sizeInBytes = 3 * (8 + 4 * 2),
@@ -525,7 +525,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       withClue(s"For data type ${key1.dataType}") {
         // All values in two tables are the same, so column stats after join are also the same.
         val join = Join(Project(Seq(key1), table1), Project(Seq(key2), table2), Inner,
-          Some(EqualTo(key1, key2)))
+          Some(EqualTo(key1, key2)), JoinHint.NONE)
         val expectedStats = Statistics(
           sizeInBytes = 1 * (8 + 2 * getColSize(key1, columnInfo1(key1))),
           rowCount = Some(1),
@@ -543,7 +543,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       outputList = Seq(nullColumn),
       rowCount = 1,
       attributeStats = AttributeMap(Seq(nullColumn -> nullColStat)))
-    val join = Join(table1, nullTable, Inner, Some(EqualTo(nameToAttr("key-1-5"), nullColumn)))
+    val join = Join(table1, nullTable, Inner,
+      Some(EqualTo(nameToAttr("key-1-5"), nullColumn)), JoinHint.NONE)
     val expectedStats = Statistics(
       sizeInBytes = 1,
       rowCount = Some(0),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a664c7338badb..44cada086489a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -862,7 +862,7 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   def join(right: Dataset[_]): DataFrame = withPlan {
-    Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
+    Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE)
   }
 
@@ -940,7 +940,7 @@ class Dataset[T] private[sql](
     // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
     // by creating a new instance for one of the branch.
     val joined = sparkSession.sessionState.executePlan(
-      Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None))
+      Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE))
       .analyzed.asInstanceOf[Join]
 
     withPlan {
@@ -948,7 +948,8 @@ class Dataset[T] private[sql](
         joined.left,
         joined.right,
         UsingJoin(JoinType(joinType), usingColumns),
-        None)
+        None,
+        JoinHint.NONE)
     }
   }
 
@@ -1001,7 +1002,7 @@ class Dataset[T] private[sql](
     // Trigger analysis so in the case of self-join, the analyzer will clone the plan.
     // After the cloning, left and right side will have distinct expression ids.
     val plan = withPlan(
-      Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)))
+      Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr), JoinHint.NONE))
       .queryExecution.analyzed.asInstanceOf[Join]
 
     // If auto self join alias is disabled, return the plan.
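Every `Join` the `Dataset` API builds starts with `JoinHint.NONE`: the public API has no way to attach a hint at construction time. Hints still enter through `Dataset.hint` or `functions.broadcast`, which the analyzer resolves to a `ResolvedHint` node that the optimizer later folds into the join (per the test changes in this patch). A minimal usage sketch:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.broadcast

    val spark = SparkSession.builder().master("local").getOrCreate()
    val large = spark.range(1000).toDF("id")
    val small = spark.range(10).toDF("id")
    // Built with JoinHint.NONE here; the broadcast hint rides along as a ResolvedHint
    // until the optimizer moves it into the Join's hint field.
    val joined = large.join(broadcast(small), "id")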
@@ -1048,7 +1049,7 @@ class Dataset[T] private[sql](
    * @since 2.1.0
    */
   def crossJoin(right: Dataset[_]): DataFrame = withPlan {
-    Join(logicalPlan, right.logicalPlan, joinType = Cross, None)
+    Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE)
   }
 
   /**
@@ -1083,7 +1084,8 @@ class Dataset[T] private[sql](
         this.logicalPlan,
         other.logicalPlan,
         JoinType(joinType),
-        Some(condition.expr))).analyzed.asInstanceOf[Join]
+        Some(condition.expr),
+        JoinHint.NONE)).analyzed.asInstanceOf[Join]
 
     if (joined.joinType == LeftSemi || joined.joinType == LeftAnti) {
       throw new AnalysisException("Invalid join type in joinWith: " + joined.joinType.sql)
@@ -1135,7 +1137,7 @@ class Dataset[T] private[sql](
     implicit val tuple2Encoder: Encoder[(T, U)] =
       ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
 
-    withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr)))
+    withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr), JoinHint.NONE))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index dbc6db62bd820..b7cc373b2df12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -208,17 +208,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       }
     }
 
-    private def canBroadcastByHints(joinType: JoinType, left: LogicalPlan, right: LogicalPlan)
-      : Boolean = {
-      val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast
-      val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast
+    private def canBroadcastByHints(
+        joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): Boolean = {
+      val buildLeft = canBuildLeft(joinType) && hint.leftHint.exists(_.broadcast)
+      val buildRight = canBuildRight(joinType) && hint.rightHint.exists(_.broadcast)
       buildLeft || buildRight
     }
 
-    private def broadcastSideByHints(joinType: JoinType, left: LogicalPlan, right: LogicalPlan)
-      : BuildSide = {
-      val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast
-      val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast
+    private def broadcastSideByHints(
+        joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): BuildSide = {
+      val buildLeft = canBuildLeft(joinType) && hint.leftHint.exists(_.broadcast)
+      val buildRight = canBuildRight(joinType) && hint.rightHint.exists(_.broadcast)
       broadcastSide(buildLeft, buildRight, left, right)
     }
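The hint checks themselves are unchanged in spirit; only their source moves. `left.stats.hints.broadcast` becomes `hint.leftHint.exists(_.broadcast)`, i.e. the strategy reads the hint directly off the `Join` node instead of fishing it out of the child's propagated statistics, which no longer carry hints at all.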
@@ -241,14 +241,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
       // --- BroadcastHashJoin --------------------------------------------------------------------
 
       // broadcast hints were specified
-      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
-        if canBroadcastByHints(joinType, left, right) =>
-        val buildSide = broadcastSideByHints(joinType, left, right)
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint)
+        if canBroadcastByHints(joinType, left, right, hint) =>
+        val buildSide = broadcastSideByHints(joinType, left, right, hint)
         Seq(joins.BroadcastHashJoinExec(
           leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
 
       // broadcast hints were not specified, so need to infer it from size and configuration.
-      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
         if canBroadcastBySizes(joinType, left, right) =>
         val buildSide = broadcastSideBySizes(joinType, left, right)
         Seq(joins.BroadcastHashJoinExec(
@@ -256,14 +256,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
       // --- ShuffledHashJoin ---------------------------------------------------------------------
 
-      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
         if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right)
           && muchSmaller(right, left) ||
           !RowOrdering.isOrderable(leftKeys) =>
         Seq(joins.ShuffledHashJoinExec(
           leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
 
-      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
         if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left)
           && muchSmaller(left, right) ||
           !RowOrdering.isOrderable(leftKeys) =>
@@ -272,7 +272,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
       // --- SortMergeJoin ------------------------------------------------------------
 
-      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
         if RowOrdering.isOrderable(leftKeys) =>
         joins.SortMergeJoinExec(
           leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
@@ -280,25 +280,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // --- Without joining keys ------------------------------------------------------------
 
       // Pick BroadcastNestedLoopJoin if one side could be broadcast
-      case j @ logical.Join(left, right, joinType, condition)
-        if canBroadcastByHints(joinType, left, right) =>
-        val buildSide = broadcastSideByHints(joinType, left, right)
+      case j @ logical.Join(left, right, joinType, condition, hint)
+        if canBroadcastByHints(joinType, left, right, hint) =>
+        val buildSide = broadcastSideByHints(joinType, left, right, hint)
         joins.BroadcastNestedLoopJoinExec(
           planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
 
-      case j @ logical.Join(left, right, joinType, condition)
+      case j @ logical.Join(left, right, joinType, condition, _)
         if canBroadcastBySizes(joinType, left, right) =>
         val buildSide = broadcastSideBySizes(joinType, left, right)
         joins.BroadcastNestedLoopJoinExec(
           planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
 
       // Pick CartesianProduct for InnerJoin
-      case logical.Join(left, right, _: InnerLike, condition) =>
+      case logical.Join(left, right, _: InnerLike, condition, _) =>
         joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
 
-      case logical.Join(left, right, joinType, condition) =>
+      case logical.Join(left, right, joinType, condition, hint) =>
         val buildSide = broadcastSide(
-          left.stats.hints.broadcast, right.stats.hints.broadcast, left, right)
+          hint.leftHint.exists(_.broadcast), hint.rightHint.exists(_.broadcast), left, right)
         // This join could be very slow or OOM
         joins.BroadcastNestedLoopJoinExec(
           planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
@@ -380,13 +380,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object StreamingJoinStrategy extends Strategy {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
       plan match {
-        case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+        case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
           if left.isStreaming && right.isStreaming =>
 
           new StreamingSymmetricHashJoinExec(
             leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
 
-        case Join(left, right, _, _) if left.isStreaming && right.isStreaming =>
+        case Join(left, right, _, _, _) if left.isStreaming && right.isStreaming =>
           throw new AnalysisException(
             "Stream-stream join without equality predicate is not supported", plan = Some(plan))
@@ -561,6 +561,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         throw new IllegalStateException(
           "logical except (all) operator should have been replaced by union, aggregate" +
             " and generate operators in the optimizer")
+      case logical.ResolvedHint(child, hints) =>
+        throw new IllegalStateException(
+          "ResolvedHint operator should have been replaced by join hint in the optimizer")
 
       case logical.DeserializeToObject(deserializer, objAttr, child) =>
         execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil
@@ -632,7 +635,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
-      case h: ResolvedHint => planLater(h.child) :: Nil
       case _ => Nil
     }
   }
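By the time physical planning runs, every `ResolvedHint` should already have been absorbed into a `Join`; the planner therefore fails loudly on a leftover hint node (the new `IllegalStateException` case) instead of silently planning through it (the deleted `case h: ResolvedHint => planLater(h.child)`).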
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 4109d9994dd8f..41f406d6c2993 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.storage.StorageLevel
@@ -184,12 +184,7 @@ case class InMemoryRelation(
   override def computeStats(): Statistics = {
     if (cacheBuilder.sizeInBytesStats.value == 0L) {
       // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
-      // Note that we should drop the hint info here. We may cache a plan whose root node is a hint
-      // node. When we lookup the cache with a semantically same plan without hint info, the plan
-      // returned by cache lookup should not have hint info. If we lookup the cache with a
-      // semantically same plan with a different hint info, `CacheManager.useCachedData` will take
-      // care of it and retain the hint info in the lookup input plan.
-      statsOfPlanToCache.copy(hints = HintInfo())
+      statsOfPlanToCache
     } else {
       Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 6e805c4f3c39a..2141be4d680f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -27,9 +27,11 @@ import org.apache.spark.executor.DataReadMethod.DataReadMethod
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
+import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.storage.{RDDBlockId, StorageLevel}
@@ -925,4 +927,23 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       }
     }
   }
+
+  test("Cache should respect the broadcast hint") {
+    val df = broadcast(spark.range(1000)).cache()
+    val df2 = spark.range(1000).cache()
+    df.count()
+    df2.count()
+
+    // Test the broadcast hint.
+    val joinPlan = df.join(df2, "id").queryExecution.optimizedPlan
+    val hint = joinPlan.collect {
+      case Join(_, _, _, _, hint) => hint
+    }
+    assert(hint.size == 1)
+    assert(hint(0).leftHint.get.broadcast)
+    assert(hint(0).rightHint.isEmpty)
+
+    // Clean-up
+    df.unpersist()
+  }
 }
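Because statistics no longer smuggle the hint across the cache boundary, the new test checks the property end to end: a cached, broadcast-hinted DataFrame must still surface as `leftHint = Some(HintInfo(broadcast = true))` on the optimized join.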
where($"a.int" === 1 && $"b.int2" === 3) assert(outerJoin2Inner.queryExecution.optimizedPlan.collect { - case j @ Join(_, _, Inner, _) => j }.size === 1) + case j @ Join(_, _, Inner, _, _) => j }.size === 1) checkAnswer( outerJoin2Inner, Row(1, 2, "1", 1, 3, "1") :: Nil) @@ -223,7 +223,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // right -> inner val rightJoin2Inner = df.join(df2, $"a.int" === $"b.int", "right").where($"a.int" > 0) assert(rightJoin2Inner.queryExecution.optimizedPlan.collect { - case j @ Join(_, _, Inner, _) => j }.size === 1) + case j @ Join(_, _, Inner, _, _) => j }.size === 1) checkAnswer( rightJoin2Inner, Row(1, 2, "1", 1, 3, "1") :: Nil) @@ -231,7 +231,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // left -> inner val leftJoin2Inner = df.join(df2, $"a.int" === $"b.int", "left").where($"b.int2" > 0) assert(leftJoin2Inner.queryExecution.optimizedPlan.collect { - case j @ Join(_, _, Inner, _) => j }.size === 1) + case j @ Join(_, _, Inner, _, _) => j }.size === 1) checkAnswer( leftJoin2Inner, Row(1, 2, "1", 1, 3, "1") :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala new file mode 100644 index 0000000000000..3652895ff43d8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.test.SharedSQLContext + +class JoinHintSuite extends PlanTest with SharedSQLContext { + import testImplicits._ + + lazy val df = spark.range(10) + lazy val df1 = df.selectExpr("id as a1", "id as a2") + lazy val df2 = df.selectExpr("id as b1", "id as b2") + lazy val df3 = df.selectExpr("id as c1", "id as c2") + + def verifyJoinHint(df: DataFrame, expectedHints: Seq[JoinHint]): Unit = { + val optimized = df.queryExecution.optimizedPlan + val joinHints = optimized collect { + case Join(_, _, _, _, hint) => hint + case _: ResolvedHint => fail("ResolvedHint should not appear after optimize.") + } + assert(joinHints == expectedHints) + } + + test("single join") { + verifyJoinHint( + df.hint("broadcast").join(df, "id"), + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: Nil + ) + verifyJoinHint( + df.join(df.hint("broadcast"), "id"), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: Nil + ) + } + + test("multiple joins") { + verifyJoinHint( + df1.join(df2.hint("broadcast").join(df3, 'b1 === 'c1).hint("broadcast"), 'a1 === 'c1), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: Nil + ) + verifyJoinHint( + df1.hint("broadcast").join(df2, 'a1 === 'b1).hint("broadcast").join(df3, 'a1 === 'c1), + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: Nil + ) + } + + test("hint scope") { + withTempView("a", "b", "c") { + df1.createOrReplaceTempView("a") + df2.createOrReplaceTempView("b") + verifyJoinHint( + sql( + """ + |select /*+ broadcast(a, b)*/ * from ( + | select /*+ broadcast(b)*/ * from a join b on a.a1 = b.b1 + |) a join ( + | select /*+ broadcast(a)*/ * from a join b on a.a1 = b.b1 + |) b on a.a1 = b.b1 + """.stripMargin), + JoinHint( + Some(HintInfo(broadcast = true)), + Some(HintInfo(broadcast = true))) :: + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: Nil + ) + } + } + + test("hint preserved after join reorder") { + withTempView("a", "b", "c") { + df1.createOrReplaceTempView("a") + df2.createOrReplaceTempView("b") + df3.createOrReplaceTempView("c") + verifyJoinHint( + sql("select /*+ broadcast(a, c)*/ * from a, b, c " + + "where a.a1 = b.b1 and b.b1 = c.c1"), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: + JoinHint( + Some(HintInfo(broadcast = true)), + None):: Nil + ) + verifyJoinHint( + sql("select /*+ broadcast(a, c)*/ * from a, c, b " + + "where a.a1 = b.b1 and b.b1 = c.c1"), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: + JoinHint( + Some(HintInfo(broadcast = true)), + None):: Nil + ) + verifyJoinHint( + sql("select /*+ broadcast(b, c)*/ * from a, c, b " + + "where a.a1 = b.b1 and b.b1 = c.c1"), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: + JoinHint( + None, + Some(HintInfo(broadcast = true))):: Nil + ) + + verifyJoinHint( + df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") + .join(df3, 'b1 === 'c1 && 'a1 < 10), + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: + JoinHint.NONE:: Nil + ) + + verifyJoinHint( + df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") + .join(df3, 'b1 === 'c1 && 'a1 < 10) + .join(df, 'b1 === 'id), + JoinHint.NONE :: + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: + JoinHint.NONE:: Nil + ) + } + } + + 
test("intersect/except") { + val dfSub = spark.range(2) + verifyJoinHint( + df.hint("broadcast").except(dfSub).join(df, "id"), + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: + JoinHint.NONE :: Nil + ) + verifyJoinHint( + df.join(df.hint("broadcast").intersect(dfSub), "id"), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: + JoinHint.NONE :: Nil + ) + } + + test("hint merge") { + verifyJoinHint( + df.hint("broadcast").filter('id > 2).hint("broadcast").join(df, "id"), + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: Nil + ) + verifyJoinHint( + df.join(df.hint("broadcast").limit(2).hint("broadcast"), "id"), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: Nil + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 02dc32d5f90ba..99842680cedfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -237,8 +237,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared ) numbers.foreach { case (input, (expectedSize, expectedRows)) => val stats = Statistics(sizeInBytes = input, rowCount = Some(input)) - val expectedString = s"sizeInBytes=$expectedSize, rowCount=$expectedRows," + - s" hints=none" + val expectedString = s"sizeInBytes=$expectedSize, rowCount=$expectedRows" assert(stats.simpleString == expectedString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 42dd0024b2582..f238148e61c39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -203,7 +203,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } test("broadcast hint in SQL") { - import org.apache.spark.sql.catalyst.plans.logical.{ResolvedHint, Join} + import org.apache.spark.sql.catalyst.plans.logical.Join spark.range(10).createOrReplaceTempView("t") spark.range(10).createOrReplaceTempView("u") @@ -216,12 +216,12 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id = u.id").queryExecution .optimizedPlan - assert(plan1.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) - assert(!plan1.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) - assert(!plan2.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) - assert(plan2.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) - assert(!plan3.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) - assert(!plan3.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) + assert(plan1.asInstanceOf[Join].hint.leftHint.get.broadcast) + assert(plan1.asInstanceOf[Join].hint.rightHint.isEmpty) + assert(plan2.asInstanceOf[Join].hint.leftHint.isEmpty) + assert(plan2.asInstanceOf[Join].hint.rightHint.get.broadcast) + assert(plan3.asInstanceOf[Join].hint.leftHint.isEmpty) + assert(plan3.asInstanceOf[Join].hint.rightHint.isEmpty) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index 22279a3a43eff..771a9730247af 100644 --- 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index 22279a3a43eff..771a9730247af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
 import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan, SparkPlanTest}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.internal.SQLConf
@@ -85,7 +85,8 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
       expectedAnswer: Seq[Row]): Unit = {
 
     def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
-      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan,
+        Inner, Some(condition), JoinHint.NONE)
       ExtractEquiJoinKeys.unapply(join)
     }
 
@@ -102,7 +103,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
 
     test(s"$testName using ShuffledHashJoin") {
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
             EnsureRequirements(left.sqlContext.sessionState.conf).apply(
@@ -121,7 +122,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
 
     testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin") { _ =>
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
             EnsureRequirements(left.sqlContext.sessionState.conf).apply(
@@ -140,7 +141,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
 
     test(s"$testName using SortMergeJoin") {
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
             EnsureRequirements(left.sqlContext.sessionState.conf).apply(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index f5edd6bbd5e69..f99a278bb2427 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.internal.SQLConf
@@ -80,7 +80,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       expectedAnswer: Seq[Product]): Unit = {
 
     def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
-      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition()))
+      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan,
+        Inner, Some(condition()), JoinHint.NONE)
       ExtractEquiJoinKeys.unapply(join)
     }
 
@@ -128,7 +129,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
 
     testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=left)") { _ =>
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
             makeBroadcastHashJoin(
@@ -140,7 +141,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
 
     testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=right)") { _ =>
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
             makeBroadcastHashJoin(
@@ -152,7 +153,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
 
     test(s"$testName using ShuffledHashJoin (build=left)") {
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
             makeShuffledHashJoin(
@@ -164,7 +165,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
 
     test(s"$testName using ShuffledHashJoin (build=right)") {
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
             makeShuffledHashJoin(
@@ -176,7 +177,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
 
     testWithWholeStageCodegenOnAndOff(s"$testName using SortMergeJoin") { _ =>
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
             makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 513248dae48be..1f04fcf6ca451 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
 import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.internal.SQLConf
@@ -72,13 +72,14 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
       expectedAnswer: Seq[Product]): Unit = {
 
     def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
-      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan,
+        Inner, Some(condition), JoinHint.NONE)
       ExtractEquiJoinKeys.unapply(join)
     }
 
     if (joinType != FullOuter) {
       test(s"$testName using ShuffledHashJoin") {
-        extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+        extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
           withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
@@ -99,7 +100,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
           case RightOuter => BuildLeft
           case _ => fail(s"Unsupported join type $joinType")
         }
-        extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+        extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
              BroadcastHashJoinExec(
@@ -112,7 +113,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
 
     test(s"$testName using SortMergeJoin") {
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
             EnsureRequirements(spark.sessionState.conf).apply(