diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9cadd13999e72..10d3a183383db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -214,13 +214,32 @@ case class HashAggregateExec( val resultVars = resultExpressions.map { e => BindReferences.bindReference(e, aggregateAttributes).genCode(ctx) } - (resultVars, s""" + val localBufVars = resultExpressions.zip(resultVars).map { case (expr, ev) => + val isNull = ctx.freshName("localBufisNull") + val value = ctx.freshName("localBufValue") + (s""" + |boolean $isNull = ${ev.isNull};\n + |${ctx.javaType(expr.dataType)} $value = ${ev.value}; + """.stripMargin, + isNull, value) + } + (localBufVars.map(e => ExprCode("", e._2, e._3)), s""" |$evaluateAggResults |${evaluateVariables(resultVars)} + |${localBufVars.map(_._1).mkString("\n")} """.stripMargin) } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { // output the aggregate buffer directly - (bufVars, "") + val localBufVars = initExpr.zip(bufVars).map { case (expr, ev) => + val isNull = ctx.freshName("localBufisNull") + val value = ctx.freshName("localBufValue") + (s""" + |boolean $isNull = ${ev.isNull};\n + |${ctx.javaType(expr.dataType)} $value = ${ev.value}; + """.stripMargin, + isNull, value) + } + (localBufVars.map(e => ExprCode("", e._2, e._3)), localBufVars.map(_._1).mkString("\n")) } else { // no aggregate function, the result should be literals val resultVars = resultExpressions.map(_.genCode(ctx))