Skip to content

Commit 23deb4b

Browse files
committed
Update
1 parent 61d1a7e commit 23deb4b

File tree

2 files changed

+30
-16
lines changed

2 files changed

+30
-16
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
213213
object CartesianProduct extends Strategy {
214214
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
215215
case logical.Join(left, right, _, None) =>
216-
execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil
216+
val buildSide =
217+
if (left.statistics.sizeInBytes <= right.statistics.sizeInBytes) {
218+
joins.BuildRight
219+
} else {
220+
joins.BuildLeft
221+
}
222+
execution.joins.CartesianProduct(planLater(left), planLater(right), buildSide) :: Nil
217223
case logical.Join(left, right, Inner, Some(condition)) =>
224+
val buildSide =
225+
if (left.statistics.sizeInBytes <= right.statistics.sizeInBytes) {
226+
joins.BuildRight
227+
} else {
228+
joins.BuildLeft
229+
}
218230
execution.Filter(condition,
219-
execution.joins.CartesianProduct(planLater(left), planLater(right))) :: Nil
231+
execution.joins.CartesianProduct(planLater(left), planLater(right), buildSide)) :: Nil
220232
case _ => Nil
221233
}
222234
}

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,26 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
2727
* :: DeveloperApi ::
2828
*/
2929
@DeveloperApi
30-
case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
31-
override def output: Seq[Attribute] = left.output ++ right.output
30+
case class CartesianProduct(
31+
left: SparkPlan,
32+
right: SparkPlan,
33+
buildSide: BuildSide) extends BinaryNode {
3234

33-
protected override def doExecute(): RDD[InternalRow] = {
34-
val leftResults = left.execute().map(_.copy())
35-
val rightResults = right.execute().map(_.copy())
35+
private val (streamed, broadcast) = buildSide match {
36+
case BuildRight => (left, right)
37+
case BuildLeft => (right, left)
38+
}
3639

37-
val cartesianRdd = if (leftResults.partitions.size > rightResults.partitions.size) {
38-
rightResults.cartesian(leftResults).mapPartitions { iter =>
39-
iter.map(tuple => (tuple._2, tuple._1))
40-
}
41-
} else {
42-
leftResults.cartesian(rightResults)
43-
}
40+
override def output: Seq[Attribute] = left.output ++ right.output
4441

45-
cartesianRdd.mapPartitions { iter =>
42+
protected override def doExecute(): RDD[InternalRow] = {
43+
val broadcastedRelation = sparkContext.broadcast(broadcast.execute().map(_.copy()))
44+
broadcastedRelation.value.cartesian(streamed.execute().map(_.copy())).mapPartitions{ iter =>
4645
val joinedRow = new JoinedRow
47-
iter.map(r => joinedRow(r._1, r._2))
46+
buildSide match {
47+
case BuildRight => iter.map(r => joinedRow(r._1, r._2))
48+
case BuildLeft => iter.map(r => joinedRow(r._2, r._1))
49+
}
4850
}
4951
}
5052
}

0 commit comments

Comments
 (0)