
 package org.apache.spark.sql.execution.joins

-import org.apache.spark.rdd.RDD
+import org.apache.spark._
+import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+
+private[spark]
+class UnsafeCartesianRDD(rdd1: RDD[UnsafeRow], rdd2: RDD[UnsafeRow])
+  extends CartesianRDD[UnsafeRow, UnsafeRow](rdd1.sparkContext, rdd1, rdd2) {
+
+  override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
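+    // Buffer the whole right-side partition in an UnsafeExternalSorter, used here purely as a
+    // spillable row buffer (no comparators are supplied).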
+    val sorter = UnsafeExternalSorter.create(
+      context.taskMemoryManager(),
+      SparkEnv.get.blockManager,
+      context,
+      null,
+      null,
+      1024,
+      SparkEnv.get.memoryManager.pageSizeBytes)
+
+    val currSplit = split.asInstanceOf[CartesianPartition]
+    var numFields = 0
+    for (y <- rdd2.iterator(currSplit.s2, context)) {
+      numFields = y.numFields()
+      sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0)
+    }
+
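+    // Each call makes a fresh pass over the buffered right-side rows; a single UnsafeRow is
+    // re-pointed at each record rather than allocating a new row per record.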
+    def createIter(): Iterator[UnsafeRow] = {
+      val iter = sorter.getIterator
+      val unsafeRow = new UnsafeRow
+      new Iterator[UnsafeRow] {
+        override def hasNext: Boolean = {
+          iter.hasNext
+        }
+        override def next(): UnsafeRow = {
+          iter.loadNext()
+          unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFields, iter.getRecordLength)
+          unsafeRow
+        }
+      }
+    }
+
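+    // Re-scan the buffered right-side rows for every left row, and release the sorter's memory
+    // once the result iterator has been fully consumed.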
+    val resultIter =
+      for (x <- rdd1.iterator(currSplit.s1, context);
+           y <- createIter()) yield (x, y)
+    CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
+      resultIter, sorter.cleanupResources)
+  }
+}


 case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
   override def output: Seq[Attribute] = left.output ++ right.output

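+  // The operator now works on UnsafeRows end to end: inputs are required in the unsafe format
+  // and the output rows are UnsafeRows as well.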
+  override def canProcessSafeRows: Boolean = false
+  override def canProcessUnsafeRows: Boolean = true
+  override def outputsUnsafeRows: Boolean = true
+
   override private[sql] lazy val metrics = Map(
     "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
     "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
@@ -39,18 +92,19 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode

     val leftResults = left.execute().map { row =>
       numLeftRows += 1
-      row.copy()
+      row.asInstanceOf[UnsafeRow]
     }
     val rightResults = right.execute().map { row =>
       numRightRows += 1
-      row.copy()
+      row.asInstanceOf[UnsafeRow]
     }

-    leftResults.cartesian(rightResults).mapPartitionsInternal { iter =>
-      val joinedRow = new JoinedRow
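+    // Pair each left row with every buffered right row via UnsafeCartesianRDD, then concatenate
+    // each (left, right) pair into a single output row with a generated UnsafeRow joiner.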
+    val pair = new UnsafeCartesianRDD(leftResults, rightResults)
+    pair.mapPartitionsInternal { iter =>
+      val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       iter.map { r =>
         numOutputRows += 1
-        joinedRow(r._1, r._2)
+        joiner.join(r._1, r._2)
       }
     }
   }