Commit 038b185

juliuszsompolski authored and gatorsmile committed
[SPARK-22103] Move HashAggregateExec parent consume to a separate function in codegen
## What changes were proposed in this pull request?

HashAggregateExec codegen uses two paths: one for the fast hash table and a generic one. It generates code for iterating over both, and each path expands the consume code of the parent operator, so that code is emitted twice. This results in a long generated function that can be a problem for the compiler (see e.g. SPARK-21603).

This patch removes the double expansion by generating the consume code in a helper function that both iterating loops simply call.

One issue with separating the `consume` code into a helper function was that a number of places relied on being in the scope of an outer `produce` loop, e.g. using `continue` to jump out of it. Such code flows are replaced with nested scopes. The compiler should handle the resulting code the same way, while the generated `consume` code no longer depends on assumptions outside its own scope.

## How was this patch tested?

Existing test coverage.

Author: Juliusz Sompolski <[email protected]>

Closes #19324 from juliuszsompolski/aggrconsumecodegen.
1 parent 2c5b9b1 commit 038b185
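The core idea, sketched below in plain Scala rather than Spark's generated Java (all names here are invented for illustration): emit the parent's consume logic once as a named helper and call it from both iteration paths, instead of expanding it into each loop body.

    object ConsumeHelperSketch extends App {
      // Before this patch, the equivalent of this body was inlined into both
      // loops below, roughly doubling the size of the generated processNext().
      private def consumeOutput(key: String, buffer: Long): Unit =
        println(s"$key -> $buffer") // stands in for the parent operator's consume code

      val fastHashMap = Seq("a" -> 1L, "b" -> 2L) // fast hash table path
      val regularHashMap = Seq("c" -> 3L)         // generic hash table path
      fastHashMap.foreach { case (k, v) => consumeOutput(k, v) }
      regularHashMap.foreach { case (k, v) => consumeOutput(k, v) }
    }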

File tree

5 files changed: +164 -89 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 19 additions & 0 deletions
@@ -242,6 +242,9 @@ class CodegenContext {
   private val classFunctions: mutable.Map[String, mutable.Map[String, String]] =
     mutable.Map(outerClassName -> mutable.Map.empty[String, String])

+  // Verbatim extra code to be added to the OuterClass.
+  private val extraCode: mutable.ListBuffer[String] = mutable.ListBuffer[String]()
+
   // Returns the size of the most recently added class.
   private def currClassSize(): Int = classSize(classes.head._1)

@@ -328,6 +331,22 @@
     (inlinedFunctions ++ initNestedClasses ++ declareNestedClasses).mkString("\n")
   }

+  /**
+   * Emits any source code added with addExtraCode
+   */
+  def emitExtraCode(): String = {
+    extraCode.mkString("\n")
+  }
+
+  /**
+   * Add extra source code to the outermost generated class.
+   * @param code verbatim source code to be added.
+   */
+  def addExtraCode(code: String): Unit = {
+    extraCode.append(code)
+    classSize(outerClassName) += code.length
+  }
+
   final val JAVA_BOOLEAN = "boolean"
   final val JAVA_BYTE = "byte"
   final val JAVA_SHORT = "short"
sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 16 additions & 6 deletions
@@ -197,11 +197,14 @@ trait CodegenSupport extends SparkPlan {
    *
    * This should be override by subclass to support codegen.
    *
-   * For example, Filter will generate the code like this:
+   * Note: The operator should not assume the existence of an outer processing loop,
+   * which it can jump from with "continue;"!
    *
+   * For example, filter could generate this:
    *   # code to evaluate the predicate expression, result is isNull1 and value2
-   *   if (isNull1 || !value2) continue;
-   *   # call consume(), which will call parent.doConsume()
+   *   if (!isNull1 && value2) {
+   *     # call consume(), which will call parent.doConsume()
+   *   }
    *
    * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input).
    */
@@ -329,6 +332,15 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
   def doCodeGen(): (CodegenContext, CodeAndComment) = {
     val ctx = new CodegenContext
     val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
+
+    // main next function.
+    ctx.addNewFunction("processNext",
+      s"""
+        protected void processNext() throws java.io.IOException {
+          ${code.trim}
+        }
+       """, inlineToOuterClass = true)
+
     val source = s"""
       public Object generate(Object[] references) {
         return new GeneratedIterator(references);
@@ -352,9 +364,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
         ${ctx.initPartition()}
       }

-      protected void processNext() throws java.io.IOException {
-        ${code.trim}
-      }
+      ${ctx.emitExtraCode()}

       ${ctx.declareAddedFunctions()}
     }
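A hypothetical end-to-end sketch of the new assembly path (the hook names follow the diff above; the snippet bodies, and the BufferedRowIterator superclass from the surrounding Spark template, are not part of this diff): processNext() is now registered through addNewFunction with inlineToOuterClass = true, so it arrives via declareAddedFunctions(), and emitExtraCode() splices in operator-supplied class-level code where processNext() used to be written directly into the template.

    import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

    object OuterClassAssemblySketch extends App {
      val ctx = new CodegenContext
      ctx.addExtraCode("// e.g. a generated fast hash map class would be emitted here")
      ctx.addNewFunction("processNext",
        s"""
           |protected void processNext() throws java.io.IOException {
           |  // code produced by the child plan
           |}
         """.stripMargin, inlineToOuterClass = true)
      // Assemble a skeleton of the outer class, as doCodeGen() does for GeneratedIterator.
      println(
        s"""
           |final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
           |  ${ctx.emitExtraCode()}
           |  ${ctx.declareAddedFunctions()}
           |}
         """.stripMargin)
    }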

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 60 additions & 29 deletions
@@ -425,12 +425,14 @@ case class HashAggregateExec(

   /**
    * Generate the code for output.
+   * @return function name for the result code.
    */
-  private def generateResultCode(
-      ctx: CodegenContext,
-      keyTerm: String,
-      bufferTerm: String,
-      plan: String): String = {
+  private def generateResultFunction(ctx: CodegenContext): String = {
+    val funcName = ctx.freshName("doAggregateWithKeysOutput")
+    val keyTerm = ctx.freshName("keyTerm")
+    val bufferTerm = ctx.freshName("bufferTerm")
+
+    val body =
     if (modes.contains(Final) || modes.contains(Complete)) {
       // generate output using resultExpressions
       ctx.currentVars = null
@@ -462,18 +464,36 @@
         $evaluateAggResults
         ${consume(ctx, resultVars)}
        """
-
     } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
-      // This should be the last operator in a stage, we should output UnsafeRow directly
-      val joinerTerm = ctx.freshName("unsafeRowJoiner")
-      ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
-        s"$joinerTerm = $plan.createUnsafeJoiner();")
-      val resultRow = ctx.freshName("resultRow")
+      // resultExpressions are Attributes of groupingExpressions and aggregateBufferAttributes.
+      assert(resultExpressions.forall(_.isInstanceOf[Attribute]))
+      assert(resultExpressions.length ==
+        groupingExpressions.length + aggregateBufferAttributes.length)
+
+      ctx.currentVars = null
+
+      ctx.INPUT_ROW = keyTerm
+      val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).genCode(ctx)
+      }
+      val evaluateKeyVars = evaluateVariables(keyVars)
+
+      ctx.INPUT_ROW = bufferTerm
+      val resultBufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).genCode(ctx)
+      }
+      val evaluateResultBufferVars = evaluateVariables(resultBufferVars)
+
+      ctx.currentVars = keyVars ++ resultBufferVars
+      val inputAttrs = resultExpressions.map(_.toAttribute)
+      val resultVars = resultExpressions.map { e =>
+        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
+      }
       s"""
-       UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
-       ${consume(ctx, null, resultRow)}
+       $evaluateKeyVars
+       $evaluateResultBufferVars
+       ${consume(ctx, resultVars)}
        """
-
     } else {
       // generate result based on grouping key
       ctx.INPUT_ROW = keyTerm
@@ -483,6 +503,13 @@
       }
       consume(ctx, eval)
     }
+    ctx.addNewFunction(funcName,
+      s"""
+        private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm)
+            throws java.io.IOException {
+          $body
+        }
+       """)
   }

   /**
@@ -581,11 +608,6 @@
     val iterTerm = ctx.freshName("mapIter")
     ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")

-    val doAgg = ctx.freshName("doAggregateWithKeys")
-    val peakMemory = metricTerm(ctx, "peakMemory")
-    val spillSize = metricTerm(ctx, "spillSize")
-    val avgHashProbe = metricTerm(ctx, "avgHashProbe")
-
     def generateGenerateCode(): String = {
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
@@ -599,10 +621,14 @@
         }
       } else ""
     }
+    ctx.addExtraCode(generateGenerateCode())

+    val doAgg = ctx.freshName("doAggregateWithKeys")
+    val peakMemory = metricTerm(ctx, "peakMemory")
+    val spillSize = metricTerm(ctx, "spillSize")
+    val avgHashProbe = metricTerm(ctx, "avgHashProbe")
     val doAggFuncName = ctx.addNewFunction(doAgg,
       s"""
-        ${generateGenerateCode}
         private void $doAgg() throws java.io.IOException {
           $hashMapTerm = $thisPlan.createHashMap();
           ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
@@ -618,7 +644,7 @@
     // generate code for output
     val keyTerm = ctx.freshName("aggKey")
     val bufferTerm = ctx.freshName("aggBuffer")
-    val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)
+    val outputFunc = generateResultFunction(ctx)
     val numOutput = metricTerm(ctx, "numOutputRows")

     // The child could change `copyResult` to true, but we had already consumed all the rows,
@@ -641,7 +667,7 @@
           $numOutput.add(1);
           UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
           UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
-          $outputCode
+          $outputFunc($keyTerm, $bufferTerm);

           if (shouldStop()) return;
         }
@@ -654,18 +680,23 @@
       val row = ctx.freshName("fastHashMapRow")
       ctx.currentVars = null
       ctx.INPUT_ROW = row
-      var schema: StructType = groupingKeySchema
-      bufferSchema.foreach(i => schema = schema.add(i))
-      val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex
-        .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) })
+      val generateKeyRow = GenerateUnsafeProjection.createCode(ctx,
+        groupingKeySchema.toAttributes.zipWithIndex
+          .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) }
+      )
+      val generateBufferRow = GenerateUnsafeProjection.createCode(ctx,
+        bufferSchema.toAttributes.zipWithIndex
+          .map { case (attr, i) =>
+            BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) })
       s"""
        | while ($iterTermForFastHashMap.hasNext()) {
        |   $numOutput.add(1);
        |   org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row =
        |     (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row)
        |     $iterTermForFastHashMap.next();
-       |   ${generateRow.code}
-       |   ${consume(ctx, Seq.empty, {generateRow.value})}
+       |   ${generateKeyRow.code}
+       |   ${generateBufferRow.code}
+       |   $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});
        |
        |   if (shouldStop()) return;
        | }
@@ -692,7 +723,7 @@
         $numOutput.add(1);
         UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
         UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
-        $outputCode
+        $outputFunc($keyTerm, $bufferTerm);

         if (shouldStop()) return;
       }
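Net effect on the generated Java, sketched below (illustrative shape only; real identifiers come from freshName): generateResultFunction registers a single output helper, the fast-hash-map loop now builds separate key and buffer UnsafeRows (generateKeyRow/generateBufferRow) so it can call the same helper as the regular-map loop, and for Partial/PartialMerge the helper consumes bound key/buffer variables instead of a row joined by UnsafeRowJoiner.

    object AggOutputShapeSketch extends App {
      // Illustrative shape of the generated code after this change.
      val shape =
        """
          |private void doAggregateWithKeysOutput(UnsafeRow keyTerm, UnsafeRow bufferTerm)
          |    throws java.io.IOException {
          |  // evaluate keys from keyTerm and aggregate values from bufferTerm,
          |  // then run the parent operator's consume code (expanded only here)
          |}
          |
          |// fast hash map loop:    doAggregateWithKeysOutput(keyRow, bufferRow);
          |// regular hash map loop: doAggregateWithKeysOutput(aggKey, aggBuffer);
        """.stripMargin
      println(shape)
    }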

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 11 additions & 7 deletions
@@ -201,11 +201,14 @@ case class FilterExec(condition: Expression, child: SparkPlan)
       ev
     }

+    // Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;"
     s"""
-       |$generated
-       |$nullChecks
-       |$numOutput.add(1);
-       |${consume(ctx, resultVars)}
+       |do {
+       |  $generated
+       |  $nullChecks
+       |  $numOutput.add(1);
+       |  ${consume(ctx, resultVars)}
+       |} while(false);
      """.stripMargin
   }

@@ -316,9 +319,10 @@ case class SampleExec(
        """.stripMargin.trim)

     s"""
-       | if ($sampler.sample() == 0) continue;
-       | $numOutput.add(1);
-       | ${consume(ctx, input)}
+       | if ($sampler.sample() != 0) {
+       |   $numOutput.add(1);
+       |   ${consume(ctx, input)}
+       | }
      """.stripMargin.trim
   }
 }
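These two rewrites show the patch's replacement for `continue`-based control flow: FilterExec's generated null/predicate checks still emit "continue;", so they are wrapped in a local "do { ... } while(false);" that the "continue;" targets, while SampleExec simply inverts its condition. A self-contained sketch of the Filter wrapping (identifiers invented):

    object DoWhileFalseSketch extends App {
      val nullChecks  = "if (isNull1 || !value2) continue;" // skips just this row
      val consumeCode = "append(unsafeRow);"                // stands in for consume(ctx, resultVars)
      // The do/while(false) gives "continue;" a local loop to jump to, so the
      // generated block no longer assumes an enclosing produce loop.
      println(
        s"""
           |do {
           |  $nullChecks
           |  numOutputRows.add(1);
           |  $consumeCode
           |} while(false);
         """.stripMargin)
    }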
