
Commit 9cdd867

Davies Liu authored and committed
[SPARK-13373] [SQL] generate sort merge join
## What changes were proposed in this pull request?

Generates code for SortMergeJoin.

## How was this patch tested?

Unit tests, plus a manual test with TPCDS Q72, which showed a 70% performance improvement (from 42s to 25s). Micro-benchmarks showed only minor improvements; the gain likely depends on the distribution of the data and the number of columns.

Author: Davies Liu <[email protected]>

Closes #11248 from davies/gen_smj.
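For intuition: a sort-merge join walks two inputs that are already sorted on the join key, advancing whichever side has the smaller key and crossing the matching groups when the keys are equal. The following is a hand-written Scala sketch of that loop, not the Java code this patch actually generates; `Row` and `mergeJoin` are illustrative names only:

```scala
// Hand-written sketch of an inner sort-merge join over key-sorted inputs.
// The generated code in this patch implements the same idea directly over
// InternalRow iterators; everything here is illustrative.
object MergeJoinSketch {
  case class Row(key: Int, value: String)

  def mergeJoin(left: IndexedSeq[Row], right: IndexedSeq[Row]): Seq[(Row, Row)] = {
    val out = Seq.newBuilder[(Row, Row)]
    var i = 0
    var j = 0
    while (i < left.length && j < right.length) {
      val k = left(i).key
      if (k < right(j).key) {
        i += 1                       // left side is behind: advance it
      } else if (k > right(j).key) {
        j += 1                       // right side is behind: advance it
      } else {
        val start = j                // keys match: find the right-side group
        while (j < right.length && right(j).key == k) j += 1
        while (i < left.length && left(i).key == k) {
          var g = start              // cross each matching left row with the group
          while (g < j) { out += ((left(i), right(g))); g += 1 }
          i += 1
        }
      }
    }
    out.result()
  }

  def main(args: Array[String]): Unit = {
    val l = Vector(Row(1, "a"), Row(2, "b"), Row(2, "c"), Row(4, "d"))
    val r = Vector(Row(2, "x"), Row(3, "y"), Row(4, "z"))
    mergeJoin(l, r).foreach(println) // (2,b)-(2,x), (2,c)-(2,x), (4,d)-(4,z)
  }
}
```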
1 parent: c481bdf

File tree

11 files changed: +360 −52 lines


core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala

Lines changed: 1 addition & 0 deletions
@@ -203,6 +203,7 @@ private[spark] class DiskBlockObjectWriter(
     numRecordsWritten += 1
     writeMetrics.incRecordsWritten(1)

+    // TODO: call updateBytesWritten() less frequently.
     if (numRecordsWritten % 32 == 0) {
       updateBytesWritten()
     }
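The only change here is a TODO noting that `updateBytesWritten()` is still comparatively expensive even when batched to every 32nd record. For reference, a minimal sketch of this amortization pattern, with `refreshBytesWritten` as a hypothetical stand-in for the real call:

```scala
// Sketch of the amortization pattern in DiskBlockObjectWriter: pay for an
// expensive refresh only once every `every` records. `refreshBytesWritten`
// is a hypothetical stand-in for updateBytesWritten().
class ThrottledWriteMetrics(refreshBytesWritten: () => Unit, every: Int = 32) {
  private var numRecordsWritten = 0L

  def recordWritten(): Unit = {
    numRecordsWritten += 1
    if (numRecordsWritten % every == 0) {
      refreshBytesWritten() // the TODO above suggests doing this even less often
    }
  }
}
```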

sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java

Lines changed: 12 additions & 11 deletions
@@ -29,12 +29,9 @@
 /**
  * An iterator interface used to pull the output from generated function for multiple operators
  * (whole stage codegen).
- *
- * TODO: replaced it by batched columnar format.
  */
-public class BufferedRowIterator {
+public abstract class BufferedRowIterator {
   protected LinkedList<InternalRow> currentRows = new LinkedList<>();
-  protected Iterator<InternalRow> input;
   // used when there is no column in output
   protected UnsafeRow unsafeRow = new UnsafeRow(0);

@@ -49,8 +46,16 @@ public InternalRow next() {
     return currentRows.remove();
   }

-  public void setInput(Iterator<InternalRow> iter) {
-    input = iter;
+  /**
+   * Initializes from array of iterators of InternalRow.
+   */
+  public abstract void init(Iterator<InternalRow> iters[]);
+
+  /**
+   * Append a row to currentRows.
+   */
+  protected void append(InternalRow row) {
+    currentRows.add(row);
   }

   /**
@@ -74,9 +79,5 @@ protected void incPeakExecutionMemory(long size) {
    *
    * After it's called, if currentRow is still null, it means no more rows left.
    */
-  protected void processNext() throws IOException {
-    if (input.hasNext()) {
-      currentRows.add(input.next());
-    }
-  }
+  protected abstract void processNext() throws IOException;
 }
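To see what the new contract demands of the generated class, here is a self-contained Scala analogue (simplified: the real class is Java and works on `InternalRow`). Inputs now arrive as an array through `init`, and subclasses emit rows via `append`; the `PassThrough` subclass reproduces what the removed default `processNext()` used to do for a single input:

```scala
import java.util.LinkedList

// Simplified Scala analogue of the now-abstract BufferedRowIterator contract;
// T stands in for InternalRow.
abstract class BufferedRowIteratorSketch[T] {
  protected val currentRows = new LinkedList[T]()

  def hasNext: Boolean = {
    if (currentRows.isEmpty) processNext() // pull more rows on demand
    !currentRows.isEmpty
  }

  def next(): T = currentRows.remove()

  /** Initializes from an array of input iterators (one or two upstreams). */
  def init(iters: Array[Iterator[T]]): Unit

  /** Appends a row to currentRows; generated code calls this to emit output. */
  protected def append(row: T): Unit = currentRows.add(row)

  /** Fills currentRows with at least one row, or leaves it empty when done. */
  protected def processNext(): Unit
}

// Equivalent of the removed non-abstract default: forward one input unchanged.
class PassThrough[T] extends BufferedRowIteratorSketch[T] {
  private var input: Iterator[T] = _
  override def init(iters: Array[Iterator[T]]): Unit = input = iters(0)
  override protected def processNext(): Unit =
    if (input.hasNext) append(input.next())
}
```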

sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala

Lines changed: 2 additions & 2 deletions
@@ -85,8 +85,8 @@ case class Expand(
     }
   }

-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
   }

   protected override def doProduce(ctx: CodegenContext): String = {

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala

Lines changed: 49 additions & 22 deletions
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.toCommentSafeString
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin}
 import org.apache.spark.sql.execution.metric.LongSQLMetricValue

 /**
@@ -40,7 +40,8 @@ trait CodegenSupport extends SparkPlan {
   /** Prefix used in the current operator's variable names. */
   private def variablePrefix: String = this match {
     case _: TungstenAggregate => "agg"
-    case _: BroadcastHashJoin => "join"
+    case _: BroadcastHashJoin => "bhj"
+    case _: SortMergeJoin => "smj"
     case _ => nodeName.toLowerCase
   }

@@ -68,9 +69,11 @@ trait CodegenSupport extends SparkPlan {
   private var parent: CodegenSupport = null

   /**
-   * Returns the RDD of InternalRow which generates the input rows.
+   * Returns all the RDDs of InternalRow which generates the input rows.
+   *
+   * Note: right now we support up to two RDDs.
    */
-  def upstream(): RDD[InternalRow]
+  def upstreams(): Seq[RDD[InternalRow]]

   /**
    * Returns Java source code to process the rows from upstream.
@@ -179,19 +182,23 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {

   override def supportCodegen: Boolean = false

-  override def upstream(): RDD[InternalRow] = {
-    child.execute()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.execute() :: Nil
   }

   override def doProduce(ctx: CodegenContext): String = {
+    val input = ctx.freshName("input")
+    // Right now, InputAdapter is only used when there is one upstream.
+    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+
     val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
     val row = ctx.freshName("row")
     ctx.INPUT_ROW = row
     ctx.currentVars = null
     val columns = exprs.map(_.gen(ctx))
     s"""
-       | while (input.hasNext()) {
-       |   InternalRow $row = (InternalRow) input.next();
+       | while ($input.hasNext()) {
+       |   InternalRow $row = (InternalRow) $input.next();
        |   ${columns.map(_.code).mkString("\n").trim}
        |   ${consume(ctx, columns).trim}
        |   if (shouldStop()) {
@@ -215,7 +222,7 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
  *
  *   -> execute()
  *       |
- *  doExecute() --------->   upstream() -------> upstream() ------> execute()
+ *  doExecute() --------->   upstreams() -------> upstreams() ------> execute()
  *       |
  *        -----------------> produce()
  *                             |
@@ -267,6 +274,9 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])

       public GeneratedIterator(Object[] references) {
         this.references = references;
+      }
+
+      public void init(scala.collection.Iterator inputs[]) {
         ${ctx.initMutableStates()}
       }

@@ -283,19 +293,33 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     // println(s"${CodeFormatter.format(cleanedSource)}")
     CodeGenerator.compile(cleanedSource)

-    plan.upstream().mapPartitions { iter =>
-
-      val clazz = CodeGenerator.compile(source)
-      val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
-      buffer.setInput(iter)
-      new Iterator[InternalRow] {
-        override def hasNext: Boolean = buffer.hasNext
-        override def next: InternalRow = buffer.next()
+    val rdds = plan.upstreams()
+    assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
+    if (rdds.length == 1) {
+      rdds.head.mapPartitions { iter =>
+        val clazz = CodeGenerator.compile(cleanedSource)
+        val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
+        buffer.init(Array(iter))
+        new Iterator[InternalRow] {
+          override def hasNext: Boolean = buffer.hasNext
+          override def next: InternalRow = buffer.next()
+        }
+      }
+    } else {
+      // Right now, we support up to two upstreams.
+      rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
+        val clazz = CodeGenerator.compile(cleanedSource)
+        val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
+        buffer.init(Array(leftIter, rightIter))
+        new Iterator[InternalRow] {
+          override def hasNext: Boolean = buffer.hasNext
+          override def next: InternalRow = buffer.next()
+        }
       }
     }
   }

-  override def upstream(): RDD[InternalRow] = {
+  override def upstreams(): Seq[RDD[InternalRow]] = {
     throw new UnsupportedOperationException
   }

@@ -312,7 +336,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     if (row != null) {
       // There is an UnsafeRow already
       s"""
-         | currentRows.add($row.copy());
+         |append($row.copy());
       """.stripMargin
     } else {
       assert(input != null)
@@ -324,13 +348,13 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
       ctx.currentVars = input
       val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
       s"""
-         | ${code.code.trim}
-         | currentRows.add(${code.value}.copy());
+         |${code.code.trim}
+         |append(${code.value}.copy());
       """.stripMargin
     } else {
       // There is no columns
       s"""
-         | currentRows.add(unsafeRow);
+         |append(unsafeRow);
       """.stripMargin
     }
   }
@@ -402,6 +426,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] {
       b.copy(left = apply(left))
     case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
       b.copy(right = apply(right))
+    case j @ SortMergeJoin(_, _, _, left, right) =>
+      // The children of SortMergeJoin should do codegen separately.
+      j.copy(left = apply(left), right = apply(right))
     case p if !supportCodegen(p) =>
       val input = apply(p)  // collapse them recursively
       inputs += input
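The two-upstream branch above leans on `RDD.zipPartitions`, which pairs the i-th partition of each RDD (both RDDs must have the same number of partitions, as the co-partitioned sorted children of a SortMergeJoin do) and passes both iterators to a single function — which is exactly how the generated iterator receives `inputs[0]` and `inputs[1]`. A standalone toy illustration, where a trivial merge stands in for the generated join loop:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Toy illustration of the zipPartitions wiring used by WholeStageCodegen for
// two upstreams. The body here just merges two streams of ints; in the real
// code it is buffer.init(Array(leftIter, rightIter)) followed by the
// generated sort-merge loop.
object ZipPartitionsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("zipPartitionsSketch").setMaster("local[2]"))
    val left = sc.parallelize(Seq(1, 3, 5, 7), numSlices = 2)
    val right = sc.parallelize(Seq(2, 3, 6, 8), numSlices = 2)

    // Partition counts must match; each pair of partitions is processed together.
    val merged = left.zipPartitions(right) { (leftIter, rightIter) =>
      (leftIter ++ rightIter).toSeq.sorted.iterator
    }

    println(merged.collect().mkString(", "))
    sc.stop()
  }
}
```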

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala

Lines changed: 2 additions & 2 deletions
@@ -121,8 +121,8 @@ case class TungstenAggregate(
     !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
   }

-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
   }

   protected override def doProduce(ctx: CodegenContext): String = {

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 12 additions & 8 deletions
@@ -31,8 +31,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)

   override def output: Seq[Attribute] = projectList.map(_.toAttribute)

-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
   }

   protected override def doProduce(ctx: CodegenContext): String = {
@@ -69,8 +69,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode with CodegenSupport {
   private[sql] override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
   }

   protected override def doProduce(ctx: CodegenContext): String = {
@@ -156,8 +156,9 @@ case class Range(
   private[sql] override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

-  override def upstream(): RDD[InternalRow] = {
-    sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
+      .map(i => InternalRow(i)) :: Nil
   }

   protected override def doProduce(ctx: CodegenContext): String = {
@@ -213,12 +214,15 @@ case class Range(
       |   }
     """.stripMargin)

+    val input = ctx.freshName("input")
+    // Right now, Range is only used when there is one upstream.
+    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
     s"""
       | // initialize Range
       | if (!$initTerm) {
       |   $initTerm = true;
-      |   if (input.hasNext()) {
-      |     initRange(((InternalRow) input.next()).getInt(0));
+      |   if ($input.hasNext()) {
+      |     initRange(((InternalRow) $input.next()).getInt(0));
       |   } else {
       |     return;
       |   }

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala

Lines changed: 2 additions & 2 deletions
@@ -99,8 +99,8 @@ case class BroadcastHashJoin(
     }
   }

-  override def upstream(): RDD[InternalRow] = {
-    streamedPlan.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    streamedPlan.asInstanceOf[CodegenSupport].upstreams()
   }

   override def doProduce(ctx: CodegenContext): String = {

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter

-
 /**
  * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD,
  * will be much faster than building the right partition for every row in left RDD, it also
