
Commit 8df584b

Davies Liu authored and committed
[SPARK-11982] [SQL] improve performance of cartesian product
This PR improves the performance of CartesianProduct by caching the rows of the right plan. After this patch, the query time of TPC-DS Q65 goes down from 28 minutes to 4 seconds (420X faster). cc nongli

Author: Davies Liu <[email protected]>

Closes #9969 from davies/improve_cartesian.
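For context, the stock CartesianRDD recomputes the right partition once for every left row, while the UnsafeCartesianRDD introduced below buffers the right rows once (in a spillable UnsafeExternalSorter) and replays them. A toy sketch of that difference, not Spark code: computeRight and the counters are hypothetical stand-ins, and a plain array stands in for the spillable buffer.

// Toy sketch of the optimization (hypothetical stand-ins, not Spark code).
object CartesianSketch {
  def main(args: Array[String]): Unit = {
    val left = Seq(1, 2, 3)
    var recomputations = 0
    def computeRight(): Iterator[String] = { recomputations += 1; Iterator("a", "b") }

    // Before: the right side is recomputed once per left row (3 times here).
    val naive = for (x <- left.iterator; y <- computeRight()) yield (x, y)
    naive.size
    println(s"naive recomputations = $recomputations")   // 3

    // After: materialize the right side once, then replay it per left row.
    recomputations = 0
    val cached = computeRight().toArray
    val fast = for (x <- left.iterator; y <- cached.iterator) yield (x, y)
    fast.size
    println(s"cached recomputations = $recomputations")  // 1
  }
}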
1 parent 17275fa commit 8df584b

File tree

4 files changed: +139 −9 lines changed


core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 63 additions & 0 deletions
@@ -21,6 +21,7 @@
 import java.io.File;
 import java.io.IOException;
 import java.util.LinkedList;
+import java.util.Queue;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
@@ -521,4 +522,66 @@ public long getKeyPrefix() {
       return upstream.getKeyPrefix();
     }
   }
+
+  /**
+   * Returns an iterator that returns the rows in the order they were inserted.
+   *
+   * It is the caller's responsibility to call `cleanupResources()`
+   * after consuming this iterator.
+   */
+  public UnsafeSorterIterator getIterator() throws IOException {
+    if (spillWriters.isEmpty()) {
+      assert(inMemSorter != null);
+      return inMemSorter.getIterator();
+    } else {
+      LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
+      for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
+        queue.add(spillWriter.getReader(blockManager));
+      }
+      if (inMemSorter != null) {
+        queue.add(inMemSorter.getIterator());
+      }
+      return new ChainedIterator(queue);
+    }
+  }
+
+  /**
+   * Chains multiple UnsafeSorterIterators together as a single one.
+   */
+  class ChainedIterator extends UnsafeSorterIterator {
+
+    private final Queue<UnsafeSorterIterator> iterators;
+    private UnsafeSorterIterator current;
+
+    public ChainedIterator(Queue<UnsafeSorterIterator> iterators) {
+      assert iterators.size() > 0;
+      this.iterators = iterators;
+      this.current = iterators.remove();
+    }
+
+    @Override
+    public boolean hasNext() {
+      while (!current.hasNext() && !iterators.isEmpty()) {
+        current = iterators.remove();
+      }
+      return current.hasNext();
+    }
+
+    @Override
+    public void loadNext() throws IOException {
+      current.loadNext();
+    }
+
+    @Override
+    public Object getBaseObject() { return current.getBaseObject(); }
+
+    @Override
+    public long getBaseOffset() { return current.getBaseOffset(); }
+
+    @Override
+    public int getRecordLength() { return current.getRecordLength(); }
+
+    @Override
+    public long getKeyPrefix() { return current.getKeyPrefix(); }
+  }
 }
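A hedged sketch of how a caller is expected to drive the new getIterator(), following the contract in the Javadoc above; `sorter` is an UnsafeExternalSorter that insertRecord() has already been called on, and `process` is a hypothetical record consumer.

import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter

// Sketch only: consumes records in insertion order, spanning the in-memory
// run and any spilled runs chained by ChainedIterator.
def consumeInInsertionOrder(
    sorter: UnsafeExternalSorter)(process: (AnyRef, Long, Int) => Unit): Unit = {
  val iter = sorter.getIterator
  try {
    while (iter.hasNext) {
      iter.loadNext()  // positions the iterator on the next record
      process(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
    }
  } finally {
    sorter.cleanupResources()  // the Javadoc puts this on the caller
  }
}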

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java

Lines changed: 7 additions & 0 deletions
@@ -226,4 +226,11 @@ public SortedIterator getSortedIterator() {
     sorter.sort(array, 0, pos / 2, sortComparator);
     return new SortedIterator(pos / 2);
   }
+
+  /**
+   * Returns an iterator over the record pointers in their original (insertion) order.
+   */
+  public SortedIterator getIterator() {
+    return new SortedIterator(pos / 2);
+  }
 }
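The `pos / 2` mirrors getSortedIterator above: the in-memory sorter appends two longs per record (a pointer and a sort prefix) into one array, so handing the first pos / 2 pairs to SortedIterator without sorting yields the records in insertion order. A toy model of that layout, read off the arithmetic above rather than taken from this diff:

// Toy model of the pointer/prefix array (hypothetical, illustrative only).
object PointerArrayModel {
  def main(args: Array[String]): Unit = {
    val array = new Array[Long](16)
    var pos = 0
    def insert(recordPointer: Long, keyPrefix: Long): Unit = {
      array(pos) = recordPointer   // even slot: where the record lives
      array(pos + 1) = keyPrefix   // odd slot: prefix used only when sorting
      pos += 2
    }
    insert(0x10L, 3L); insert(0x20L, 1L); insert(0x30L, 2L)
    println(s"records = ${pos / 2}")  // 3, exactly the count passed above
    // Walking the pairs without sorting reproduces insertion order.
    (0 until pos by 2).foreach(i => println(f"ptr=0x${array(i)}%x prefix=${array(i + 1)}"))
  }
}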

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala

Lines changed: 68 additions & 8 deletions
@@ -17,16 +17,75 @@
 
 package org.apache.spark.sql.execution.joins
 
-import org.apache.spark.rdd.RDD
+import org.apache.spark._
+import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+
+/**
+ * An optimized CartesianRDD for UnsafeRow, which caches the rows from the second child RDD.
+ * This is much faster than rebuilding the right partition for every row in the left RDD, and
+ * it also materializes the right RDD (in case the right RDD is nondeterministic).
+ */
+private[spark]
+class UnsafeCartesianRDD(left: RDD[UnsafeRow], right: RDD[UnsafeRow], numFieldsOfRight: Int)
+  extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
+
+  override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
+    // We will not sort the rows, so prefixComparator and recordComparator are null.
+    val sorter = UnsafeExternalSorter.create(
+      context.taskMemoryManager(),
+      SparkEnv.get.blockManager,
+      context,
+      null,
+      null,
+      1024,
+      SparkEnv.get.memoryManager.pageSizeBytes)
+
+    val partition = split.asInstanceOf[CartesianPartition]
+    for (y <- rdd2.iterator(partition.s2, context)) {
+      sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0)
+    }
+
+    // Create an iterator from the sorter and wrap it as Iterator[UnsafeRow]
+    def createIter(): Iterator[UnsafeRow] = {
+      val iter = sorter.getIterator
+      val unsafeRow = new UnsafeRow
+      new Iterator[UnsafeRow] {
+        override def hasNext: Boolean = {
+          iter.hasNext
+        }
+        override def next(): UnsafeRow = {
+          iter.loadNext()
+          unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFieldsOfRight,
+            iter.getRecordLength)
+          unsafeRow
+        }
+      }
+    }
+
+    val resultIter =
+      for (x <- rdd1.iterator(partition.s1, context);
+           y <- createIter()) yield (x, y)
+    CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
+      resultIter, sorter.cleanupResources)
+  }
+}
 
 
 case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
   override def output: Seq[Attribute] = left.output ++ right.output
 
+  override def canProcessSafeRows: Boolean = false
+  override def canProcessUnsafeRows: Boolean = true
+  override def outputsUnsafeRows: Boolean = true
+
   override private[sql] lazy val metrics = Map(
     "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
     "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
@@ -39,18 +98,19 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
 
     val leftResults = left.execute().map { row =>
      numLeftRows += 1
-      row.copy()
+      row.asInstanceOf[UnsafeRow]
     }
     val rightResults = right.execute().map { row =>
       numRightRows += 1
-      row.copy()
+      row.asInstanceOf[UnsafeRow]
     }
 
-    leftResults.cartesian(rightResults).mapPartitionsInternal { iter =>
-      val joinedRow = new JoinedRow
+    val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
+    pair.mapPartitionsInternal { iter =>
+      val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       iter.map { r =>
         numOutputRows += 1
-        joinedRow(r._1, r._2)
+        joiner.join(r._1, r._2)
       }
     }
   }
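To see the new path in action: any join without a join condition plans as CartesianProduct. A usage sketch against the Spark 1.6-era API, assuming a SQLContext named sqlContext is in scope:

// Usage sketch (Spark 1.6-era API assumed; `sqlContext` is in scope).
val df1 = sqlContext.range(0, 1000).toDF("a")
val df2 = sqlContext.range(0, 1000).toDF("b")
val product = df1.join(df2)  // no join condition => CartesianProduct
product.explain()            // the plan should contain CartesianProduct
println(product.count())     // 1000 * 1000 = 1000000 rows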

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -317,7 +317,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     testSparkPlanMetrics(df, 1, Map(
       1L -> ("CartesianProduct", Map(
         "number of left rows" -> 12L, // left needs to be scanned twice
-        "number of right rows" -> 12L, // right is read 6 times
+        "number of right rows" -> 4L, // right is read twice
         "number of output rows" -> 12L)))
     )
   }
