@@ -27,24 +27,26 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
  * :: DeveloperApi ::
  */
 @DeveloperApi
-case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
-  override def output: Seq[Attribute] = left.output ++ right.output
+case class CartesianProduct(
+    left: SparkPlan,
+    right: SparkPlan,
+    buildSide: BuildSide) extends BinaryNode {
 
-  protected override def doExecute(): RDD[InternalRow] = {
-    val leftResults = left.execute().map(_.copy())
-    val rightResults = right.execute().map(_.copy())
+  private val (streamed, broadcast) = buildSide match {
+    case BuildRight => (left, right)
+    case BuildLeft => (right, left)
+  }
 
-    val cartesianRdd = if (leftResults.partitions.size > rightResults.partitions.size) {
-      rightResults.cartesian(leftResults).mapPartitions { iter =>
-        iter.map(tuple => (tuple._2, tuple._1))
-      }
-    } else {
-      leftResults.cartesian(rightResults)
-    }
+  override def output: Seq[Attribute] = left.output ++ right.output
 
-    cartesianRdd.mapPartitions { iter =>
+  protected override def doExecute(): RDD[InternalRow] = {
+    val broadcastedRelation = sparkContext.broadcast(broadcast.execute().map(_.copy()))
+    streamed.execute().map(_.copy()).cartesian(broadcastedRelation.value).mapPartitions { iter =>
       val joinedRow = new JoinedRow
-      iter.map(r => joinedRow(r._1, r._2))
+      buildSide match {
+        case BuildRight => iter.map(r => joinedRow(r._1, r._2))
+        case BuildLeft => iter.map(r => joinedRow(r._2, r._1))
+      }
     }
   }
 }
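
For reference, here is a minimal self-contained sketch of the build-side logic this change introduces, with plain Seqs standing in for RDDs. BuildSide, BuildLeft, and BuildRight below are stand-ins modeled on the ADT in org.apache.spark.sql.execution.joins; the pair tuples play the role of JoinedRow, and everything else is illustrative rather than the actual operator:

    sealed trait BuildSide
    case object BuildLeft extends BuildSide
    case object BuildRight extends BuildSide

    object CartesianSketch {
      // Cross-join two relations, iterating the streamed side outermost and
      // always emitting columns in (left, right) order -- mirroring how the
      // operator swaps the JoinedRow arguments when the build side is the left.
      def cartesian[L, R](left: Seq[L], right: Seq[R], buildSide: BuildSide): Seq[(L, R)] =
        buildSide match {
          case BuildRight => for (l <- left; r <- right) yield (l, r)
          case BuildLeft  => for (r <- right; l <- left) yield (l, r)
        }

      def main(args: Array[String]): Unit = {
        // Either build side yields the same rows; only iteration order differs.
        println(cartesian(Seq(1, 2), Seq("a", "b"), BuildRight))
        println(cartesian(Seq(1, 2), Seq("a", "b"), BuildLeft))
      }
    }

Whichever side is chosen as the build side gets broadcast and iterated against every row of the streamed side, so the choice affects shuffle and memory behavior but not the rows produced.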