 
 package org.apache.spark.sql.execution.joins
 
-import org.apache.spark.rdd.RDD
+import org.apache.spark._
+import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+
+/**
+ * An optimized CartesianRDD for UnsafeRow, which caches the rows from the second child
+ * RDD. This is much faster than rebuilding the right partition for every row in the left
+ * RDD, and it also materializes the right RDD (in case the right RDD is nondeterministic).
+ */
+private[spark]
+class UnsafeCartesianRDD(left: RDD[UnsafeRow], right: RDD[UnsafeRow], numFieldsOfRight: Int)
+  extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
+
+  override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
+    // We will not sort the rows, so prefixComparator and recordComparator are null.
+    val sorter = UnsafeExternalSorter.create(
+      context.taskMemoryManager(),
+      SparkEnv.get.blockManager,
+      context,
+      null,
+      null,
+      1024,
+      SparkEnv.get.memoryManager.pageSizeBytes)
+
+    val partition = split.asInstanceOf[CartesianPartition]
+    for (y <- rdd2.iterator(partition.s2, context)) {
+      sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0)
+    }
+
+    // Create an iterator from the sorter and wrap it as an Iterator[UnsafeRow].
+    def createIter(): Iterator[UnsafeRow] = {
+      val iter = sorter.getIterator
+      val unsafeRow = new UnsafeRow
+      new Iterator[UnsafeRow] {
+        override def hasNext: Boolean = {
+          iter.hasNext
+        }
+        override def next(): UnsafeRow = {
+          iter.loadNext()
+          unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFieldsOfRight,
+            iter.getRecordLength)
+          unsafeRow
+        }
+      }
+    }
+
+    val resultIter =
+      for (x <- rdd1.iterator(partition.s1, context);
+           y <- createIter()) yield (x, y)
+    CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
+      resultIter, sorter.cleanupResources)
+  }
+}
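
For intuition, here is a minimal plain-Scala sketch of the buffer-and-replay pattern the class above implements. It is illustrative only: the helper name cartesianBufferRight is hypothetical, and the in-memory ArrayBuffer stands in for the spillable UnsafeExternalSorter.

object CartesianSketch {
  import scala.collection.mutable.ArrayBuffer

  // Illustrative sketch only: buffer the right-side iterator once, then replay
  // it for every left row. The real UnsafeCartesianRDD spills the buffered rows
  // to disk via UnsafeExternalSorter instead of holding them all on the heap.
  def cartesianBufferRight[A, B](
      left: Iterator[A],
      right: Iterator[B],
      onComplete: () => Unit): Iterator[(A, B)] = {
    val buffered = ArrayBuffer.empty[B]
    right.foreach(buffered += _)

    val joined = for (a <- left; b <- buffered.iterator) yield (a, b)

    // Analogous to CompletionIterator: fire the cleanup exactly once, as soon
    // as the joined output has been fully drained.
    new Iterator[(A, B)] {
      private var completed = false
      override def hasNext: Boolean = {
        val more = joined.hasNext
        if (!more && !completed) { completed = true; onComplete() }
        more
      }
      override def next(): (A, B) = joined.next()
    }
  }
}

For example, CartesianSketch.cartesianBufferRight(Iterator(1, 2), Iterator("a", "b"), () => println("done")) yields the four pairs and prints "done" once, mirroring how the CompletionIterator above runs sorter.cleanupResources after the last joined row is consumed.
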
 
 
 case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
   override def output: Seq[Attribute] = left.output ++ right.output
 
+  override def canProcessSafeRows: Boolean = false
+  override def canProcessUnsafeRows: Boolean = true
+  override def outputsUnsafeRows: Boolean = true
+
   override private[sql] lazy val metrics = Map(
     "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
     "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
@@ -39,18 +98,19 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode
 
     val leftResults = left.execute().map { row =>
       numLeftRows += 1
-      row.copy()
+      row.asInstanceOf[UnsafeRow]
     }
     val rightResults = right.execute().map { row =>
       numRightRows += 1
-      row.copy()
+      row.asInstanceOf[UnsafeRow]
     }
 
-    leftResults.cartesian(rightResults).mapPartitionsInternal { iter =>
-      val joinedRow = new JoinedRow
+    val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
+    pair.mapPartitionsInternal { iter =>
+      val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       iter.map { r =>
         numOutputRows += 1
-        joinedRow(r._1, r._2)
+        joiner.join(r._1, r._2)
       }
     }
   }
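
As a rough per-row analogue of what the generated joiner does: GenerateUnsafeRowJoiner actually emits code that splices the two UnsafeRow binary layouts (null bitmap, fixed-length and variable-length regions) into one output row, so the array-based RowJoiner below is a hypothetical stand-in, not the generated code.

// Hypothetical stand-in for the generated joiner: concatenate the fields of
// two input rows into one output row. Because the real joiner produces a
// single UnsafeRow rather than a JoinedRow wrapper around two rows, the
// operator can declare outputsUnsafeRows = true for downstream consumers.
final class RowJoiner(numLeftFields: Int, numRightFields: Int) {
  def join(leftRow: Array[Any], rightRow: Array[Any]): Array[Any] = {
    require(leftRow.length == numLeftFields && rightRow.length == numRightFields)
    val out = new Array[Any](numLeftFields + numRightFields)
    System.arraycopy(leftRow, 0, out, 0, numLeftFields)
    System.arraycopy(rightRow, 0, out, numLeftFields, numRightFields)
    out
  }
}
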