@@ -44,6 +44,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
4444 }
4545 }
4646
47+ // TODO: document how this strategy uses the inputs' size statistics to pick the join build side.
4748 /**
4849 * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be
4950 * evaluated by matching hash keys.
@@ -82,11 +83,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
8283 if left.statistics.sizeInBytes <= sqlContext.autoConvertJoinSize =>
8384 broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft )
8485
// Fallback for inner equi-joins that no earlier case claimed: plan a shuffled hash join.
// The hash table is built from whichever input has the smaller estimated size. Comparing the
// two estimates directly (instead of testing only `right` against the broadcast threshold)
// avoids building on `left` when `left` is in fact the larger input.
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
  val buildSide =
    if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight
    else BuildLeft
  val hashJoin =
    execution.ShuffledHashJoin(
      leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
  // Any residual (non-equi) predicate is applied as a filter on top of the join output.
  condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
9194
9295 case _ => Nil
@@ -147,11 +150,16 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
147150 }
148151 }
149152
/**
 * Plans any logical join as a broadcast nested loop join.
 *
 * The physical operator takes its children positionally as (streamed, broadcast). The logical
 * left child is always passed as the streamed side and the logical right child as the
 * broadcast side, so the child order seen by the operator matches the logical plan.
 */
object BroadcastNestedLoopJoin extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case logical.Join(left, right, joinType, condition) =>
      // Do not swap the children based on their size: `joinType` and `condition` are handed to
      // the operator unchanged, so putting the logical left child in the broadcast position
      // would make a left/right outer join preserve the wrong side.
      // NOTE(review): swapping presumably also reorders the operator's output columns relative
      // to the logical plan -- confirm against execution.BroadcastNestedLoopJoin.
      // TODO: to broadcast the smaller input, give the physical operator an explicit build-side
      // argument rather than reordering its children here.
      execution.BroadcastNestedLoopJoin(
        planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil
    case _ => Nil
  }
}
0 commit comments