@@ -425,12 +425,14 @@ case class HashAggregateExec(
425425
426426 /**
427427 * Generate the code for output.
428+ * @return function name for the result code.
428429 */
429- private def generateResultCode (
430- ctx : CodegenContext ,
431- keyTerm : String ,
432- bufferTerm : String ,
433- plan : String ): String = {
430+ private def generateResultFunction (ctx : CodegenContext ): String = {
431+ val funcName = ctx.freshName(" doAggregateWithKeysOutput" )
432+ val keyTerm = ctx.freshName(" keyTerm" )
433+ val bufferTerm = ctx.freshName(" bufferTerm" )
434+
435+ val body =
434436 if (modes.contains(Final ) || modes.contains(Complete )) {
435437 // generate output using resultExpressions
436438 ctx.currentVars = null
@@ -462,18 +464,36 @@ case class HashAggregateExec(
462464 $evaluateAggResults
463465 ${consume(ctx, resultVars)}
464466 """
465-
466467 } else if (modes.contains(Partial ) || modes.contains(PartialMerge )) {
467- // This should be the last operator in a stage, we should output UnsafeRow directly
468- val joinerTerm = ctx.freshName(" unsafeRowJoiner" )
469- ctx.addMutableState(classOf [UnsafeRowJoiner ].getName, joinerTerm,
470- s " $joinerTerm = $plan.createUnsafeJoiner(); " )
471- val resultRow = ctx.freshName(" resultRow" )
468+ // resultExpressions are Attributes of groupingExpressions and aggregateBufferAttributes.
469+ assert(resultExpressions.forall(_.isInstanceOf [Attribute ]))
470+ assert(resultExpressions.length ==
471+ groupingExpressions.length + aggregateBufferAttributes.length)
472+
473+ ctx.currentVars = null
474+
475+ ctx.INPUT_ROW = keyTerm
476+ val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
477+ BoundReference (i, e.dataType, e.nullable).genCode(ctx)
478+ }
479+ val evaluateKeyVars = evaluateVariables(keyVars)
480+
481+ ctx.INPUT_ROW = bufferTerm
482+ val resultBufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
483+ BoundReference (i, e.dataType, e.nullable).genCode(ctx)
484+ }
485+ val evaluateResultBufferVars = evaluateVariables(resultBufferVars)
486+
487+ ctx.currentVars = keyVars ++ resultBufferVars
488+ val inputAttrs = resultExpressions.map(_.toAttribute)
489+ val resultVars = resultExpressions.map { e =>
490+ BindReferences .bindReference(e, inputAttrs).genCode(ctx)
491+ }
472492 s """
473- UnsafeRow $resultRow = $joinerTerm.join( $keyTerm, $bufferTerm);
474- ${consume(ctx, null , resultRow)}
493+ $evaluateKeyVars
494+ $evaluateResultBufferVars
495+ ${consume(ctx, resultVars)}
475496 """
476-
477497 } else {
478498 // generate result based on grouping key
479499 ctx.INPUT_ROW = keyTerm
@@ -483,6 +503,13 @@ case class HashAggregateExec(
483503 }
484504 consume(ctx, eval)
485505 }
506+ ctx.addNewFunction(funcName,
507+ s """
508+ private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm)
509+ throws java.io.IOException {
510+ $body
511+ }
512+ """ )
486513 }
487514
488515 /**
@@ -581,11 +608,6 @@ case class HashAggregateExec(
581608 val iterTerm = ctx.freshName(" mapIter" )
582609 ctx.addMutableState(classOf [KVIterator [UnsafeRow , UnsafeRow ]].getName, iterTerm, " " )
583610
584- val doAgg = ctx.freshName(" doAggregateWithKeys" )
585- val peakMemory = metricTerm(ctx, " peakMemory" )
586- val spillSize = metricTerm(ctx, " spillSize" )
587- val avgHashProbe = metricTerm(ctx, " avgHashProbe" )
588-
589611 def generateGenerateCode (): String = {
590612 if (isFastHashMapEnabled) {
591613 if (isVectorizedHashMapEnabled) {
@@ -599,10 +621,14 @@ case class HashAggregateExec(
599621 }
600622 } else " "
601623 }
624+ ctx.addExtraCode(generateGenerateCode())
602625
626+ val doAgg = ctx.freshName(" doAggregateWithKeys" )
627+ val peakMemory = metricTerm(ctx, " peakMemory" )
628+ val spillSize = metricTerm(ctx, " spillSize" )
629+ val avgHashProbe = metricTerm(ctx, " avgHashProbe" )
603630 val doAggFuncName = ctx.addNewFunction(doAgg,
604631 s """
605- ${generateGenerateCode}
606632 private void $doAgg() throws java.io.IOException {
607633 $hashMapTerm = $thisPlan.createHashMap();
608634 ${child.asInstanceOf [CodegenSupport ].produce(ctx, this )}
@@ -618,7 +644,7 @@ case class HashAggregateExec(
618644 // generate code for output
619645 val keyTerm = ctx.freshName(" aggKey" )
620646 val bufferTerm = ctx.freshName(" aggBuffer" )
621- val outputCode = generateResultCode (ctx, keyTerm, bufferTerm, thisPlan )
647+ val outputFunc = generateResultFunction (ctx)
622648 val numOutput = metricTerm(ctx, " numOutputRows" )
623649
624650 // The child could change `copyResult` to true, but we had already consumed all the rows,
@@ -641,7 +667,7 @@ case class HashAggregateExec(
641667 $numOutput.add(1);
642668 UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
643669 UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
644- $outputCode
670+ $outputFunc ( $keyTerm , $bufferTerm );
645671
646672 if (shouldStop()) return;
647673 }
@@ -654,18 +680,23 @@ case class HashAggregateExec(
654680 val row = ctx.freshName(" fastHashMapRow" )
655681 ctx.currentVars = null
656682 ctx.INPUT_ROW = row
657- var schema : StructType = groupingKeySchema
658- bufferSchema.foreach(i => schema = schema.add(i))
659- val generateRow = GenerateUnsafeProjection .createCode(ctx, schema.toAttributes.zipWithIndex
660- .map { case (attr, i) => BoundReference (i, attr.dataType, attr.nullable) })
683+ val generateKeyRow = GenerateUnsafeProjection .createCode(ctx,
684+ groupingKeySchema.toAttributes.zipWithIndex
685+ .map { case (attr, i) => BoundReference (i, attr.dataType, attr.nullable) }
686+ )
687+ val generateBufferRow = GenerateUnsafeProjection .createCode(ctx,
688+ bufferSchema.toAttributes.zipWithIndex
689+ .map { case (attr, i) =>
690+ BoundReference (groupingKeySchema.length + i, attr.dataType, attr.nullable) })
661691 s """
662692 | while ( $iterTermForFastHashMap.hasNext()) {
663693 | $numOutput.add(1);
664694 | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row =
665695 | (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row)
666696 | $iterTermForFastHashMap.next();
667- | ${generateRow.code}
668- | ${consume(ctx, Seq .empty, {generateRow.value})}
697+ | ${generateKeyRow.code}
698+ | ${generateBufferRow.code}
699+ | $outputFunc( ${generateKeyRow.value}, ${generateBufferRow.value});
669700 |
670701 | if (shouldStop()) return;
671702 | }
@@ -692,7 +723,7 @@ case class HashAggregateExec(
692723 $numOutput.add(1);
693724 UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
694725 UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
695- $outputCode
726+ $outputFunc ( $keyTerm , $bufferTerm );
696727
697728 if (shouldStop()) return;
698729 }
0 commit comments