
Commit b3669f2

Author: Davies Liu
generate sort merge join
1 parent 892b2dd commit b3669f2

File tree: 10 files changed (+290, -52 lines)

core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala

Lines changed: 2 additions & 1 deletion
@@ -97,6 +97,7 @@ private[spark] class DiskBlockObjectWriter(
   override def close() {
     if (initialized) {
       Utils.tryWithSafeFinally {
+        updateBytesWritten()
         if (syncWrites) {
           // Force outstanding writes to disk and track how long it takes
           objOut.flush()
@@ -203,7 +204,7 @@ private[spark] class DiskBlockObjectWriter(
     numRecordsWritten += 1
     writeMetrics.incRecordsWritten(1)

-    if (numRecordsWritten % 32 == 0) {
+    if (numRecordsWritten % 1024 == 0) {
       updateBytesWritten()
     }
   }
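
The two hunks above batch the bytes-written metric: updateBytesWritten() now runs every 1024 records instead of every 32, and close() performs one final update so the tail of the last batch is still counted. A self-contained sketch of that pattern, with hypothetical names that are not Spark's API:

  // Hypothetical sketch: refresh an expensive metric every `interval`
  // records, plus once more on close so the final partial batch is counted.
  final class BatchedMetric(interval: Int)(refresh: () => Unit) {
    private var records = 0L

    def recordWritten(): Unit = {
      records += 1
      if (records % interval == 0) refresh() // amortized: once per `interval` records
    }

    def close(): Unit = refresh() // covers the records since the last refresh
  }

With an interval of 1024 the refresh runs 32 times less often than before, while the extra call in close() keeps the final byte count exact.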

sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java

Lines changed: 12 additions & 11 deletions
@@ -29,12 +29,9 @@
 /**
  * An iterator interface used to pull the output from generated function for multiple operators
  * (whole stage codegen).
- *
- * TODO: replaced it by batched columnar format.
  */
-public class BufferedRowIterator {
+public abstract class BufferedRowIterator {
   protected LinkedList<InternalRow> currentRows = new LinkedList<>();
-  protected Iterator<InternalRow> input;
   // used when there is no column in output
   protected UnsafeRow unsafeRow = new UnsafeRow(0);

@@ -49,8 +46,16 @@ public InternalRow next() {
     return currentRows.remove();
   }

-  public void setInput(Iterator<InternalRow> iter) {
-    input = iter;
+  /**
+   * Initializes from array of iterators of InternalRow.
+   */
+  public abstract void init(Iterator<InternalRow> iters[]);
+
+  /**
+   * Append a row to currentRows.
+   */
+  protected void append(InternalRow row) {
+    currentRows.add(row);
   }

   /**
@@ -74,9 +79,5 @@ protected void incPeakExecutionMemory(long size) {
    *
    * After it's called, if currentRow is still null, it means no more rows left.
    */
-  protected void processNext() throws IOException {
-    if (input.hasNext()) {
-      currentRows.add(input.next());
-    }
-  }
+  protected abstract void processNext() throws IOException;
 }
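
With setInput() replaced by the abstract init(Iterator<InternalRow> iters[]), every generated class now wires up its own upstream iterators and buffers output rows through append(). For reference, a minimal Scala sketch of a subclass that reproduces the removed default processNext() (hypothetical, not actual generated code):

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.execution.BufferedRowIterator

  // Hypothetical subclass mirroring the removed default behavior: pull one
  // row from a single upstream per call and buffer it for next().
  class PassThroughIterator extends BufferedRowIterator {
    private var input: Iterator[InternalRow] = _

    override def init(iters: Array[Iterator[InternalRow]]): Unit = {
      input = iters(0) // this sketch assumes exactly one upstream
    }

    override protected def processNext(): Unit = {
      if (input.hasNext) {
        append(input.next()) // append() queues the row in currentRows
      }
    }
  }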

sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala

Lines changed: 2 additions & 2 deletions
@@ -87,8 +87,8 @@ case class Expand(
     }
   }

-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
   }

   protected override def doProduce(ctx: CodegenContext): String = {

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala

Lines changed: 48 additions & 22 deletions
@@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, LongSQLMetricValue, SQLMetric}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin}
+import org.apache.spark.sql.execution.metric.LongSQLMetricValue

 /**
  * An interface for those physical operators that support codegen.
@@ -39,6 +39,7 @@ trait CodegenSupport extends SparkPlan {
   private def variablePrefix: String = this match {
     case _: TungstenAggregate => "agg"
     case _: BroadcastHashJoin => "bhj"
+    case _: SortMergeJoin => "smj"
     case _ => nodeName.toLowerCase
   }

@@ -66,9 +67,11 @@ trait CodegenSupport extends SparkPlan {
   private var parent: CodegenSupport = null

   /**
-   * Returns the RDD of InternalRow which generates the input rows.
+   * Returns all the RDDs of InternalRow which generates the input rows.
+   *
+   * Note: right now we support up to two RDDs.
    */
-  def upstream(): RDD[InternalRow]
+  def upstreams(): Seq[RDD[InternalRow]]

   /**
    * Returns Java source code to process the rows from upstream.
@@ -172,19 +175,23 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {

   override def supportCodegen: Boolean = false

-  override def upstream(): RDD[InternalRow] = {
-    child.execute()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.execute() :: Nil
   }

   override def doProduce(ctx: CodegenContext): String = {
+    val input = ctx.freshName("input")
+    // Right now, Range is only used when there is one upstream.
+    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+
     val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
     val row = ctx.freshName("row")
     ctx.INPUT_ROW = row
     ctx.currentVars = null
     val columns = exprs.map(_.gen(ctx))
     s"""
-      | while (input.hasNext()) {
-      |   InternalRow $row = (InternalRow) input.next();
+      | while ($input.hasNext()) {
+      |   InternalRow $row = (InternalRow) $input.next();
       |   ${columns.map(_.code).mkString("\n").trim}
       |   ${consume(ctx, columns).trim}
       |   if (shouldStop()) {
@@ -208,7 +215,7 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
 *
 * -> execute()
 *     |
- *  doExecute() ---------> upstream() -------> upstream() ------> execute()
+ *  doExecute() ---------> upstreams() -------> upstreams() ------> execute()
 *     |
 *      -----------------> produce()
 *         |
@@ -260,6 +267,9 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])

       public GeneratedIterator(Object[] references) {
         this.references = references;
+      }
+
+      public void init(scala.collection.Iterator inputs[]) {
         ${ctx.initMutableStates()}
       }

@@ -276,19 +286,32 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     // println(s"${CodeFormatter.format(cleanedSource)}")
     CodeGenerator.compile(cleanedSource)

-    plan.upstream().mapPartitions { iter =>
-
-      val clazz = CodeGenerator.compile(source)
-      val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
-      buffer.setInput(iter)
-      new Iterator[InternalRow] {
-        override def hasNext: Boolean = buffer.hasNext
-        override def next: InternalRow = buffer.next()
+    val rdds = plan.upstreams()
+    if (rdds.length == 1) {
+      rdds.head.mapPartitions { iter =>
+        val clazz = CodeGenerator.compile(cleanedSource)
+        val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
+        buffer.init(Array(iter))
+        new Iterator[InternalRow] {
+          override def hasNext: Boolean = buffer.hasNext
+          override def next: InternalRow = buffer.next()
+        }
+      }
+    } else {
+      // Right now, we support up to two upstreams.
+      rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
+        val clazz = CodeGenerator.compile(cleanedSource)
+        val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
+        buffer.init(Array(leftIter, rightIter))
+        new Iterator[InternalRow] {
+          override def hasNext: Boolean = buffer.hasNext
+          override def next: InternalRow = buffer.next()
+        }
       }
     }
   }

-  override def upstream(): RDD[InternalRow] = {
+  override def upstreams(): Seq[RDD[InternalRow]] = {
     throw new UnsupportedOperationException
   }

@@ -305,7 +328,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     if (row != null) {
       // There is an UnsafeRow already
       s"""
-        | currentRows.add($row.copy());
+        |append($row.copy());
       """.stripMargin
     } else {
       assert(input != null)
@@ -317,13 +340,13 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
       ctx.currentVars = input
       val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
       s"""
-        | ${code.code.trim}
-        | currentRows.add(${code.value}.copy());
+        |${code.code.trim}
+        |append(${code.value}.copy());
       """.stripMargin
     } else {
       // There is no columns
       s"""
-        | currentRows.add(unsafeRow);
+        |append(unsafeRow);
       """.stripMargin
     }
   }
@@ -395,6 +418,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
       b.copy(left = apply(left))
     case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) =>
       b.copy(right = apply(right))
+    case j @ SortMergeJoin(_, _, _, left, right) =>
+      // The children of SortMergeJoin should do codegen separately.
+      j.copy(left = apply(left), right = apply(right))
     case p if !supportCodegen(p) =>
       val input = apply(p) // collapse them recursively
       inputs += input
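
The heart of the change is in doExecute() above: a single upstream keeps the old mapPartitions path, while two upstreams (the SortMergeJoin case) are combined with zipPartitions so each task sees the matching pair of partitions. A simplified Scala sketch of that dispatch, where compileBuffer and toRowIterator are hypothetical stand-ins for the compile-and-wrap steps in the diff:

  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.execution.BufferedRowIterator

  object GeneratedExecSketch {
    // Stand-in for CodeGenerator.compile(cleanedSource).generate(references).
    def compileBuffer(): BufferedRowIterator = ???

    // Wrap the buffered iterator the same way both branches of the diff do.
    def toRowIterator(buffer: BufferedRowIterator): Iterator[InternalRow] =
      new Iterator[InternalRow] {
        override def hasNext: Boolean = buffer.hasNext
        override def next(): InternalRow = buffer.next()
      }

    // One upstream -> mapPartitions; two upstreams -> zipPartitions, with
    // both partition iterators handed to the generated class via init().
    def executeGenerated(rdds: Seq[RDD[InternalRow]]): RDD[InternalRow] = rdds match {
      case Seq(single) =>
        single.mapPartitions { iter =>
          val buffer = compileBuffer()
          buffer.init(Array(iter))
          toRowIterator(buffer)
        }
      case Seq(left, right) => // at most two upstreams are supported right now
        left.zipPartitions(right) { (leftIter, rightIter) =>
          val buffer = compileBuffer()
          buffer.init(Array(leftIter, rightIter))
          toRowIterator(buffer)
        }
    }
  }

zipPartitions requires both RDDs to have the same number of partitions, which SortMergeJoin already guarantees since its children are co-partitioned and sorted on the join keys.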

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala

Lines changed: 2 additions & 2 deletions
@@ -121,8 +121,8 @@ case class TungstenAggregate(
     !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
   }

-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
   }

   protected override def doProduce(ctx: CodegenContext): String = {

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 12 additions & 8 deletions
@@ -31,8 +31,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)

   override def output: Seq[Attribute] = projectList.map(_.toAttribute)

-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
   }

   protected override def doProduce(ctx: CodegenContext): String = {
@@ -69,8 +69,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
   private[sql] override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
   }

   protected override def doProduce(ctx: CodegenContext): String = {
@@ -156,8 +156,9 @@ case class Range(
   private[sql] override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

-  override def upstream(): RDD[InternalRow] = {
-    sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
+      .map(i => InternalRow(i)) :: Nil
   }

   protected override def doProduce(ctx: CodegenContext): String = {
@@ -213,12 +214,15 @@ case class Range(
       | }
     """.stripMargin)

+    val input = ctx.freshName("input")
+    // Right now, Range is only used when there is one upstream.
+    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
     s"""
       | // initialize Range
       | if (!$initTerm) {
       |   $initTerm = true;
-      |   if (input.hasNext()) {
-      |     initRange(((InternalRow) input.next()).getInt(0));
+      |   if ($input.hasNext()) {
+      |     initRange(((InternalRow) $input.next()).getInt(0));
       |   } else {
       |     return;
       |   }

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala

Lines changed: 2 additions & 2 deletions
@@ -115,8 +115,8 @@ case class BroadcastHashJoin(
   // the term for hash relation
   private var relationTerm: String = _

-  override def upstream(): RDD[InternalRow] = {
-    streamedPlan.asInstanceOf[CodegenSupport].upstream()
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    streamedPlan.asInstanceOf[CodegenSupport].upstreams()
   }

   override def doProduce(ctx: CodegenContext): String = {

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter

-
 /**
  * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD,
  * will be much faster than building the right partition for every row in left RDD, it also
