Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,40 +31,6 @@ import org.apache.spark.sql.internal.SQLConf
* Cost-based join reorder.
* We may have several join reorder algorithms in the future. This object is the entry point for
* these algorithms and chooses which one to use.
*
* Note that join strategy hints, e.g. the broadcast hint, do not interfere with the reordering.
* Such hints will be applied on the equivalent counterparts (i.e., join between the same relations
* regardless of the join order) of the original nodes after reordering.
* For example, the plan before reordering is like:
*
* Join
* / \
* Hint1 t4
* /
* Join
* / \
* Join t3
* / \
* Hint2 t2
* /
* t1
*
* The original join order as illustrated above is "((t1 JOIN t2) JOIN t3) JOIN t4", and after
* reordering, the new join order is "((t1 JOIN t3) JOIN t2) JOIN t4", so the new plan will be like:
*
* Join
* / \
* Hint1 t4
* /
* Join
* / \
* Join t2
* / \
* t1 t3
*
* "Hint1" is applied on "(t1 JOIN t3) JOIN t2" as it is equivalent to the original hinted node,
* "(t1 JOIN t2) JOIN t3"; while "Hint2" has disappeared from the new plan since there is no
* equivalent node to "t1 JOIN t2".
*/
object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {

Expand All @@ -74,30 +40,24 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
if (!conf.cboEnabled || !conf.joinReorderEnabled) {
plan
} else {
// Use a map to track the hints on the join items.
val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
val result = plan transformDown {
// Start reordering with a joinable item, which is an InnerLike join with conditions.
case j @ Join(_, _, _: InnerLike, Some(cond), _) =>
reorder(j, j.output, hintMap)
case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), _))
if projectList.forall(_.isInstanceOf[Attribute]) =>
reorder(p, p.output, hintMap)
// Avoid reordering if a join hint is present.
case j @ Join(_, _, _: InnerLike, Some(cond), hint) if hint == JoinHint.NONE =>
reorder(j, j.output)
case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), hint))
if projectList.forall(_.isInstanceOf[Attribute]) && hint == JoinHint.NONE =>
reorder(p, p.output)
}
// After reordering is finished, convert OrderedJoin back to Join.
result transform {
case OrderedJoin(left, right, jt, cond) =>
val joinHint = JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet))
Join(left, right, jt, cond, joinHint)
case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond, JoinHint.NONE)
}
}
}

private def reorder(
plan: LogicalPlan,
output: Seq[Attribute],
hintMap: mutable.HashMap[AttributeSet, HintInfo]): LogicalPlan = {
val (items, conditions) = extractInnerJoins(plan, hintMap)
private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = {
val (items, conditions) = extractInnerJoins(plan)
val result =
// Do reordering if the number of items is appropriate and join conditions exist.
// We also need to check if costs of all items can be evaluated.
Expand All @@ -115,20 +75,16 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
* Extracts items of consecutive inner joins and join conditions.
* This method works for bushy trees and left/right deep trees.
*/
private def extractInnerJoins(
plan: LogicalPlan,
hintMap: mutable.HashMap[AttributeSet, HintInfo]): (Seq[LogicalPlan], Set[Expression]) = {
private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = {
plan match {
case Join(left, right, _: InnerLike, Some(cond), hint) =>
hint.leftHint.foreach(hintMap.put(left.outputSet, _))
hint.rightHint.foreach(hintMap.put(right.outputSet, _))
val (leftPlans, leftConditions) = extractInnerJoins(left, hintMap)
val (rightPlans, rightConditions) = extractInnerJoins(right, hintMap)
case Join(left, right, _: InnerLike, Some(cond), _) =>
val (leftPlans, leftConditions) = extractInnerJoins(left)
val (rightPlans, rightConditions) = extractInnerJoins(right)
(leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++
leftConditions ++ rightConditions)
case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _))
if projectList.forall(_.isInstanceOf[Attribute]) =>
extractInnerJoins(j, hintMap)
extractInnerJoins(j)
case _ =>
(Seq(plan), Set())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,11 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
*
* @param input a list of LogicalPlans to inner join and the type of inner join.
* @param conditions a list of conditions for the join.
* @param hintMap a map of relation output attribute sets to their corresponding hints.
*/
@tailrec
final def createOrderedJoin(
input: Seq[(LogicalPlan, InnerLike)],
conditions: Seq[Expression],
hintMap: Map[AttributeSet, HintInfo]): LogicalPlan = {
conditions: Seq[Expression]): LogicalPlan = {
assert(input.size >= 2)
if (input.size == 2) {
val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin)
Expand All @@ -58,8 +56,8 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
case (Inner, Inner) => Inner
case (_, _) => Cross
}
val join = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And),
JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
val join = Join(left, right, innerJoinType,
joinConditions.reduceLeftOption(And), JoinHint.NONE)
if (others.nonEmpty) {
Filter(others.reduceLeft(And), join)
} else {
Expand All @@ -82,27 +80,27 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
val joinedRefs = left.outputSet ++ right.outputSet
val (joinConditions, others) = conditions.partition(
e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And),
JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
val joined = Join(left, right, innerJoinType,
joinConditions.reduceLeftOption(And), JoinHint.NONE)

// should not have reference to same logical plan
createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others, hintMap)
createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others)
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p @ ExtractFiltersAndInnerJoins(input, conditions, hintMap)
case p @ ExtractFiltersAndInnerJoins(input, conditions)
if input.size > 2 && conditions.nonEmpty =>
val reordered = if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) {
val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions)
if (starJoinPlan.nonEmpty) {
val rest = input.filterNot(starJoinPlan.contains(_))
createOrderedJoin(starJoinPlan ++ rest, conditions, hintMap)
createOrderedJoin(starJoinPlan ++ rest, conditions)
} else {
createOrderedJoin(input, conditions, hintMap)
createOrderedJoin(input, conditions)
}
} else {
createOrderedJoin(input, conditions, hintMap)
createOrderedJoin(input, conditions)
}

if (p.sameOutput(reordered)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,35 +166,27 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
* was involved in an explicit cross join. Also returns the entire list of join conditions for
* the left-deep tree.
*/
def flattenJoin(
plan: LogicalPlan,
hintMap: mutable.HashMap[AttributeSet, HintInfo],
parentJoinType: InnerLike = Inner)
def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner)
: (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match {
case Join(left, right, joinType: InnerLike, cond, hint) =>
val (plans, conditions) = flattenJoin(left, hintMap, joinType)
hint.leftHint.map(hintMap.put(left.outputSet, _))
hint.rightHint.map(hintMap.put(right.outputSet, _))
case Join(left, right, joinType: InnerLike, cond, hint) if hint == JoinHint.NONE =>
val (plans, conditions) = flattenJoin(left, joinType)
(plans ++ Seq((right, joinType)), conditions ++
cond.toSeq.flatMap(splitConjunctivePredicates))
case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, _)) =>
val (plans, conditions) = flattenJoin(j, hintMap)
case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, hint)) if hint == JoinHint.NONE =>
val (plans, conditions) = flattenJoin(j)
(plans, conditions ++ splitConjunctivePredicates(filterCondition))

case _ => (Seq((plan, parentJoinType)), Seq.empty)
}

def unapply(plan: LogicalPlan)
: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression], Map[AttributeSet, HintInfo])]
: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]
= plan match {
case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, _)) =>
val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
val flattened = flattenJoin(f, hintMap)
Some((flattened._1, flattened._2, hintMap.toMap))
case j @ Join(_, _, joinType, _, _) =>
val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
val flattened = flattenJoin(j, hintMap)
Some((flattened._1, flattened._2, hintMap.toMap))
case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, hint))
if hint == JoinHint.NONE =>
Some(flattenJoin(f))
case j @ Join(_, _, joinType, _, hint) if hint == JoinHint.NONE =>
Some(flattenJoin(j))
case _ => None
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class JoinOptimizationSuite extends PlanTest {
def testExtractCheckCross
(plan: LogicalPlan, expected: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]) {
assert(
ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, e._2, Map.empty)))
ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, e._2)))
}

testExtract(x, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,77 +292,56 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
assertEqualPlans(originalPlan, bestPlan)
}

test("hints preservation") {
// Apply hints if we find an equivalent node in the new plan, otherwise discard them.
test("don't reorder if hints present") {
val originalPlan =
t1.join(t2.hint("broadcast")).hint("broadcast").join(t4.join(t3).hint("broadcast"))
.where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))

val bestPlan =
t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.hint("broadcast")
.join(
t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
.hint("broadcast"),
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.hint("broadcast")
.join(
t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
.hint("broadcast"),
Inner,
Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))

assertEqualPlans(originalPlan, bestPlan)
assertEqualPlans(originalPlan, originalPlan)

val originalPlan2 =
t1.join(t2).hint("broadcast").join(t3).hint("broadcast").join(t4.hint("broadcast"))
.where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))

val bestPlan2 =
t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.hint("broadcast")
.join(
t4.hint("broadcast")
.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
.select(outputsOf(t1, t2, t3, t4): _*)
.join(t4, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
.hint("broadcast")
.join(t3, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))

assertEqualPlans(originalPlan2, bestPlan2)
assertEqualPlans(originalPlan2, originalPlan2)
}

val originalPlan3 =
t1.join(t4).hint("broadcast")
.join(t2.hint("broadcast")).hint("broadcast")
.join(t3.hint("broadcast"))
.where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
test("reorder below and above the hint node") {
val originalPlan =
t1.join(t2).join(t3)
.where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.hint("broadcast").join(t4)

val bestPlan3 =
t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.join(
t4.join(t3.hint("broadcast"),
Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
.select(outputsOf(t1, t4, t2, t3): _*)
val bestPlan =
t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.select(outputsOf(t1, t2, t3): _*)
.hint("broadcast").join(t4)

assertEqualPlans(originalPlan3, bestPlan3)
assertEqualPlans(originalPlan, bestPlan)

val originalPlan4 =
t2.hint("broadcast")
.join(t4).hint("broadcast")
.join(t3.hint("broadcast")).hint("broadcast")
.join(t1)
.where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
val originalPlan2 =
t1.join(t2).join(t3)
.where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t4.hint("broadcast"))

val bestPlan4 =
t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.join(
t4.join(t3.hint("broadcast"),
Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
.select(outputsOf(t2, t4, t3, t1): _*)
val bestPlan2 =
t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.select(outputsOf(t1, t2, t3): _*)
.join(t4.hint("broadcast"))

assertEqualPlans(originalPlan4, bestPlan4)
assertEqualPlans(originalPlan2, bestPlan2)
}

private def assertEqualPlans(
Expand Down
26 changes: 21 additions & 5 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql

import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

class JoinHintSuite extends PlanTest with SharedSQLContext {
Expand Down Expand Up @@ -100,7 +101,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
}
}

test("hint preserved after join reorder") {
test("hints prevent join reorder") {
withTempView("a", "b", "c") {
df1.createOrReplaceTempView("a")
df2.createOrReplaceTempView("b")
Expand All @@ -118,12 +119,10 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
verifyJoinHint(
sql("select /*+ broadcast(a, c)*/ * from a, c, b " +
"where a.a1 = b.b1 and b.b1 = c.c1"),
JoinHint(
None,
Some(HintInfo(broadcast = true))) ::
JoinHint.NONE ::
JoinHint(
Some(HintInfo(broadcast = true)),
None):: Nil
Some(HintInfo(broadcast = true))):: Nil
)
verifyJoinHint(
sql("select /*+ broadcast(b, c)*/ * from a, c, b " +
Expand Down Expand Up @@ -199,4 +198,21 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
None) :: Nil
)
}

test("hints prevent cost-based join reorder") {
withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") {
val join = df.join(df, "id")
val broadcasted = join.hint("broadcast")
verifyJoinHint(
join.join(broadcasted, "id").join(broadcasted, "id"),
JoinHint(
None,
Some(HintInfo(broadcast = true))) ::
JoinHint(
None,
Some(HintInfo(broadcast = true))) ::
JoinHint.NONE :: JoinHint.NONE :: JoinHint.NONE :: Nil
)
}
}
}