
Commit 162268c

Author: Davies Liu
improve performance of cartesian product
Parent: 52bc25c

File tree: 3 files changed (+123, -8 lines)


core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 57 additions & 0 deletions

@@ -21,6 +21,7 @@
 import java.io.File;
 import java.io.IOException;
 import java.util.LinkedList;
+import java.util.Queue;

 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
@@ -519,4 +520,60 @@ public long getKeyPrefix() {
       return upstream.getKeyPrefix();
     }
   }
+
+  /**
+   * Returns an iterator. It is the caller's responsibility to call `cleanupResources()`
+   * after consuming this iterator.
+   */
+  public UnsafeSorterIterator getIterator() throws IOException {
+    if (spillWriters.isEmpty()) {
+      assert(inMemSorter != null);
+      return inMemSorter.getIterator();
+    } else {
+      Queue<UnsafeSorterIterator> queue = new LinkedList<>();
+      for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
+        queue.add(spillWriter.getReader(blockManager));
+      }
+      if (inMemSorter != null) {
+        queue.add(inMemSorter.getIterator());
+      }
+      return new ChainedIterator(queue);
+    }
+  }
+
+  class ChainedIterator extends UnsafeSorterIterator {
+    private final Queue<UnsafeSorterIterator> iterators;
+    private UnsafeSorterIterator current = null;
+    public ChainedIterator(Queue<UnsafeSorterIterator> iters) {
+      this.iterators = iters;
+      this.current = iters.remove();
+    }
+
+    @Override
+    public boolean hasNext() {
+      if (!current.hasNext()) {
+        if (!iterators.isEmpty()) {
+          current = iterators.remove();
+        }
+      }
+      return current.hasNext();
+    }
+
+    @Override
+    public void loadNext() throws IOException {
+      current.loadNext();
+    }
+
+    @Override
+    public Object getBaseObject() { return current.getBaseObject(); }
+
+    @Override
+    public long getBaseOffset() { return current.getBaseOffset(); }
+
+    @Override
+    public int getRecordLength() { return current.getRecordLength(); }
+
+    @Override
+    public long getKeyPrefix() { return current.getKeyPrefix(); }
+  }
 }
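
The new getIterator() avoids sorting entirely: it concatenates the readers for every spill file with the in-memory iterator, so records come back in insertion order across all spills. As a rough illustration of the chaining idea (a Scala sketch, not the committed Java code; the element type Rec and class name Chained are hypothetical):

import scala.collection.mutable

// Drain a queue of iterators one after another, presenting them as a single iterator.
class Chained[Rec](iters: mutable.Queue[Iterator[Rec]]) extends Iterator[Rec] {
  private var current: Iterator[Rec] =
    if (iters.nonEmpty) iters.dequeue() else Iterator.empty

  override def hasNext: Boolean = {
    // Skip past exhausted iterators before answering.
    while (!current.hasNext && iters.nonEmpty) {
      current = iters.dequeue()
    }
    current.hasNext
  }

  override def next(): Rec = {
    if (!hasNext) throw new NoSuchElementException
    current.next()
  }
}

Note the while loop in the sketch: the committed hasNext() advances at most one iterator per call, which suffices as long as every spill reader yields at least one record.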

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java

Lines changed: 4 additions & 0 deletions

@@ -231,4 +231,8 @@ public SortedIterator getSortedIterator() {
     sorter.sort(array, 0, pos / 2, sortComparator);
     return new SortedIterator(memoryManager, pos, array);
   }
+
+  public SortedIterator getIterator() {
+    return new SortedIterator(memoryManager, pos, array);
+  }
 }
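
getIterator() is getSortedIterator() minus the sorter.sort(...) call: the cartesian product only needs to replay the inserted rows, not order them, so the O(n log n) sort is skipped. A toy sketch of that access pattern (illustrative Scala, not Spark code; ReplayBuffer is a made-up name):

import scala.collection.mutable.ArrayBuffer

// Buffer rows once, then replay them in insertion order as many times as needed.
class ReplayBuffer[T] {
  private val rows = ArrayBuffer.empty[T]
  def insert(row: T): Unit = rows += row
  // No sort: rows come back exactly as inserted, like getIterator above.
  def iterator: Iterator[T] = rows.iterator
}

object ReplayDemo extends App {
  val buf = new ReplayBuffer[Int]
  Seq(3, 1, 2).foreach(buf.insert)
  assert(buf.iterator.toSeq == Seq(3, 1, 2)) // insertion order preserved
}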

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala

Lines changed: 62 additions & 8 deletions

@@ -17,16 +17,69 @@

 package org.apache.spark.sql.execution.joins

-import org.apache.spark.rdd.RDD
+import org.apache.spark._
+import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+
+private[spark]
+class UnsafeCartesianRDD(rdd1: RDD[UnsafeRow], rdd2: RDD[UnsafeRow])
+  extends CartesianRDD[UnsafeRow, UnsafeRow](rdd1.sparkContext, rdd1, rdd2) {
+
+  override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
+    val sorter = UnsafeExternalSorter.create(
+      context.taskMemoryManager(),
+      SparkEnv.get.blockManager,
+      context,
+      null,
+      null,
+      1024,
+      SparkEnv.get.memoryManager.pageSizeBytes)
+
+    val currSplit = split.asInstanceOf[CartesianPartition]
+    var numFields = 0
+    for (y <- rdd2.iterator(currSplit.s2, context)) {
+      numFields = y.numFields()
+      sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0)
+    }
+
+    def createIter(): Iterator[UnsafeRow] = {
+      val iter = sorter.getIterator
+      val unsafeRow = new UnsafeRow
+      new Iterator[UnsafeRow] {
+        override def hasNext: Boolean = {
+          iter.hasNext
+        }
+        override def next(): UnsafeRow = {
+          iter.loadNext()
+          unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFields, iter.getRecordLength)
+          unsafeRow
+        }
+      }
+    }
+
+    val resultIter =
+      for (x <- rdd1.iterator(currSplit.s1, context);
+           y <- createIter()) yield (x, y)
+    CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
+      resultIter, sorter.cleanupResources)
+  }
+}


 case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
   override def output: Seq[Attribute] = left.output ++ right.output

+  override def canProcessSafeRows: Boolean = false
+  override def canProcessUnsafeRows: Boolean = true
+  override def outputsUnsafeRows: Boolean = true
+
   override private[sql] lazy val metrics = Map(
     "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
     "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
@@ -39,18 +92,19 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode

     val leftResults = left.execute().map { row =>
       numLeftRows += 1
-      row.copy()
+      row.asInstanceOf[UnsafeRow]
     }
     val rightResults = right.execute().map { row =>
       numRightRows += 1
-      row.copy()
+      row.asInstanceOf[UnsafeRow]
     }

-    leftResults.cartesian(rightResults).mapPartitionsInternal { iter =>
-      val joinedRow = new JoinedRow
+    val pair = new UnsafeCartesianRDD(leftResults, rightResults)
+    pair.mapPartitionsInternal { iter =>
+      val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       iter.map { r =>
         numOutputRows += 1
-        joinedRow(r._1, r._2)
+        joiner.join(r._1, r._2)
       }
     }
   }
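
Taken together: instead of rdd1.cartesian(rdd2) with a copy of every row, UnsafeCartesianRDD buffers the right-hand partition once in an UnsafeExternalSorter (which can spill to disk under memory pressure), replays it via createIter() for each left row, and wraps the result in a CompletionIterator so that sorter.cleanupResources() runs once the output is consumed. The generated UnsafeRowJoiner then concatenates each pair without the per-row JoinedRow wrapper. A toy in-memory model of that control flow (illustrative only; no spilling, cleanup, or code generation):

object CartesianSketch extends App {
  // Buffer the right side once, then make a fresh pass over it per left row.
  def cartesian[A, B](left: Iterator[A], right: Iterator[B]): Iterator[(A, B)] = {
    val buffered = right.toVector           // stands in for the spillable sorter
    for (x <- left; y <- buffered.iterator) // like the for-comprehension in compute()
      yield (x, y)
  }

  val pairs = cartesian(Iterator("a", "b"), Iterator(1, 2, 3)).toSeq
  assert(pairs.size == 2 * 3)
  assert(pairs == Seq(("a", 1), ("a", 2), ("a", 3), ("b", 1), ("b", 2), ("b", 3)))
}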
