@@ -262,37 +262,37 @@ case class HashAggregateExec(
262262 // Extracts all the input variable references for a given `aggExpr`. This result will be used
263263 // to split aggregation into small functions.
264264 private def getInputVariableReferences (
265- ctx : CodegenContext ,
266- aggExpr : Expression ,
265+ context : CodegenContext ,
266+ aggregateExpression : Expression ,
267267 subExprs : Map [Expression , SubExprEliminationState ]): Set [(String , String )] = {
268268 // `argSet` collects all the pairs of variable names and their types, the first in the pair is
269269 // a type name and the second is a variable name.
270270 val argSet = mutable.Set [(String , String )]()
271- val stack = mutable.Stack [Expression ](aggExpr )
271+ val stack = mutable.Stack [Expression ](aggregateExpression )
272272 while (stack.nonEmpty) {
273273 stack.pop() match {
274274 case e if subExprs.contains(e) =>
275275 val exprCode = subExprs(e)
276276 if (CodegenContext .isJavaIdentifier(exprCode.value)) {
277- argSet += ((ctx .javaType(e.dataType), exprCode.value))
277+ argSet += ((context .javaType(e.dataType), exprCode.value))
278278 }
279279 if (CodegenContext .isJavaIdentifier(exprCode.isNull)) {
280280 argSet += ((" boolean" , exprCode.isNull))
281281 }
282282 // Since the children possibly has common expressions, we push them here
283283 stack.pushAll(e.children)
284284 case ref : BoundReference
285- if ctx .currentVars != null && ctx .currentVars(ref.ordinal) != null =>
286- val value = ctx .currentVars(ref.ordinal).value
287- val isNull = ctx .currentVars(ref.ordinal).isNull
285+ if context .currentVars != null && context .currentVars(ref.ordinal) != null =>
286+ val value = context .currentVars(ref.ordinal).value
287+ val isNull = context .currentVars(ref.ordinal).isNull
288288 if (CodegenContext .isJavaIdentifier(value)) {
289- argSet += ((ctx .javaType(ref.dataType), value))
289+ argSet += ((context .javaType(ref.dataType), value))
290290 }
291291 if (CodegenContext .isJavaIdentifier(isNull)) {
292292 argSet += ((" boolean" , isNull))
293293 }
294294 case _ : BoundReference =>
295- argSet += ((" InternalRow" , ctx .INPUT_ROW ))
295+ argSet += ((" InternalRow" , context .INPUT_ROW ))
296296 case e =>
297297 stack.pushAll(e.children)
298298 }
@@ -303,30 +303,30 @@ case class HashAggregateExec(
303303
304304 // Splits aggregate code into small functions because JVMs does not compile too long functions
305305 private def splitAggregateExpressions (
306- ctx : CodegenContext ,
307- aggExprs : Seq [Expression ],
308- evalAndUpdateCodes : Seq [String ],
306+ context : CodegenContext ,
307+ aggregateExpressions : Seq [Expression ],
308+ codes : Seq [String ],
309309 subExprs : Map [Expression , SubExprEliminationState ],
310310 otherArgs : Seq [(String , String )] = Seq .empty): Seq [String ] = {
311- aggExprs .zipWithIndex.map { case (aggExpr, i) =>
312- val args = (getInputVariableReferences(ctx , aggExpr, subExprs) ++ otherArgs).toSeq
311+ aggregateExpressions .zipWithIndex.map { case (aggExpr, i) =>
312+ val args = (getInputVariableReferences(context , aggExpr, subExprs) ++ otherArgs).toSeq
313313
314314 // This method gives up splitting the code if the parameter length goes over
315315 // `maxParamNumInJavaMethod`.
316316 if (args.size <= sqlContext.conf.maxParamNumInJavaMethod) {
317- val doAggVal = ctx .freshName(s " doAggregateVal_ ${aggExpr.prettyName}" )
317+ val doAggVal = context .freshName(s " doAggregateVal_ ${aggExpr.prettyName}" )
318318 val argList = args.map(a => s " ${a._1} ${a._2}" ).mkString(" , " )
319- val doAggValFuncName = ctx .addNewFunction(doAggVal,
319+ val doAggValFuncName = context .addNewFunction(doAggVal,
320320 s """
321321 | private void $doAggVal( $argList) throws java.io.IOException {
322- | ${evalAndUpdateCodes (i)}
322+ | ${codes (i)}
323323 | }
324324 """ .stripMargin)
325325
326326 val inputVariables = args.map(_._2).mkString(" , " )
327327 s " $doAggValFuncName( $inputVariables); "
328328 } else {
329- evalAndUpdateCodes (i)
329+ codes (i)
330330 }
331331 }
332332 }
@@ -377,7 +377,10 @@ case class HashAggregateExec(
377377 }
378378
379379 val updateAggValCode = splitAggregateExpressions(
380- ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states)
380+ context = ctx,
381+ aggregateExpressions = boundUpdateExpr,
382+ codes = evalAndUpdateCodes,
383+ subExprs = subExprs.states)
381384
382385 s """
383386 | // do aggregate
@@ -946,11 +949,11 @@ case class HashAggregateExec(
946949 }
947950
948951 val updateAggValCode = splitAggregateExpressions(
949- ctx,
950- boundUpdateExpr,
951- evalAndUpdateCodes,
952- subExprs.states,
953- Seq ((" InternalRow" , unsafeRowBuffer)))
952+ context = ctx,
953+ aggregateExpressions = boundUpdateExpr,
954+ codes = evalAndUpdateCodes,
955+ subExprs = subExprs .states,
956+ otherArgs = Seq ((" InternalRow" , unsafeRowBuffer)))
954957
955958 s """
956959 | // do aggregate
@@ -991,11 +994,11 @@ case class HashAggregateExec(
991994 }
992995
993996 val updateAggValCode = splitAggregateExpressions(
994- ctx,
995- boundUpdateExpr,
996- evalAndUpdateCodes,
997- subExprs.states,
998- Seq ((" InternalRow" , fastRowBuffer)))
997+ context = ctx,
998+ aggregateExpressions = boundUpdateExpr,
999+ codes = evalAndUpdateCodes,
1000+ subExprs = subExprs .states,
1001+ otherArgs = Seq ((" InternalRow" , fastRowBuffer)))
9991002
10001003 // If fast hash map is on, we first generate code to update row in fast hash map, if the
10011004 // previous loop up hit fast hash map. Otherwise, update row in regular hash map.
0 commit comments