Skip to content

Commit 2555e5a

Browse files
committed
Add parameter names
1 parent 0e5d366 commit 2555e5a

File tree

1 file changed

+32
-29
lines changed

1 file changed

+32
-29
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -262,37 +262,37 @@ case class HashAggregateExec(
262262
// Extracts all the input variable references for a given `aggExpr`. This result will be used
263263
// to split aggregation into small functions.
264264
private def getInputVariableReferences(
265-
ctx: CodegenContext,
266-
aggExpr: Expression,
265+
context: CodegenContext,
266+
aggregateExpression: Expression,
267267
subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
268268
// `argSet` collects all the pairs of variable names and their types, the first in the pair is
269269
// a type name and the second is a variable name.
270270
val argSet = mutable.Set[(String, String)]()
271-
val stack = mutable.Stack[Expression](aggExpr)
271+
val stack = mutable.Stack[Expression](aggregateExpression)
272272
while (stack.nonEmpty) {
273273
stack.pop() match {
274274
case e if subExprs.contains(e) =>
275275
val exprCode = subExprs(e)
276276
if (CodegenContext.isJavaIdentifier(exprCode.value)) {
277-
argSet += ((ctx.javaType(e.dataType), exprCode.value))
277+
argSet += ((context.javaType(e.dataType), exprCode.value))
278278
}
279279
if (CodegenContext.isJavaIdentifier(exprCode.isNull)) {
280280
argSet += (("boolean", exprCode.isNull))
281281
}
282282
// Since the children possibly has common expressions, we push them here
283283
stack.pushAll(e.children)
284284
case ref: BoundReference
285-
if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null =>
286-
val value = ctx.currentVars(ref.ordinal).value
287-
val isNull = ctx.currentVars(ref.ordinal).isNull
285+
if context.currentVars != null && context.currentVars(ref.ordinal) != null =>
286+
val value = context.currentVars(ref.ordinal).value
287+
val isNull = context.currentVars(ref.ordinal).isNull
288288
if (CodegenContext.isJavaIdentifier(value)) {
289-
argSet += ((ctx.javaType(ref.dataType), value))
289+
argSet += ((context.javaType(ref.dataType), value))
290290
}
291291
if (CodegenContext.isJavaIdentifier(isNull)) {
292292
argSet += (("boolean", isNull))
293293
}
294294
case _: BoundReference =>
295-
argSet += (("InternalRow", ctx.INPUT_ROW))
295+
argSet += (("InternalRow", context.INPUT_ROW))
296296
case e =>
297297
stack.pushAll(e.children)
298298
}
@@ -303,30 +303,30 @@ case class HashAggregateExec(
303303

304304
// Splits aggregate code into small functions because JVMs does not compile too long functions
305305
private def splitAggregateExpressions(
306-
ctx: CodegenContext,
307-
aggExprs: Seq[Expression],
308-
evalAndUpdateCodes: Seq[String],
306+
context: CodegenContext,
307+
aggregateExpressions: Seq[Expression],
308+
codes: Seq[String],
309309
subExprs: Map[Expression, SubExprEliminationState],
310310
otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
311-
aggExprs.zipWithIndex.map { case (aggExpr, i) =>
312-
val args = (getInputVariableReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq
311+
aggregateExpressions.zipWithIndex.map { case (aggExpr, i) =>
312+
val args = (getInputVariableReferences(context, aggExpr, subExprs) ++ otherArgs).toSeq
313313

314314
// This method gives up splitting the code if the parameter length goes over
315315
// `maxParamNumInJavaMethod`.
316316
if (args.size <= sqlContext.conf.maxParamNumInJavaMethod) {
317-
val doAggVal = ctx.freshName(s"doAggregateVal_${aggExpr.prettyName}")
317+
val doAggVal = context.freshName(s"doAggregateVal_${aggExpr.prettyName}")
318318
val argList = args.map(a => s"${a._1} ${a._2}").mkString(", ")
319-
val doAggValFuncName = ctx.addNewFunction(doAggVal,
319+
val doAggValFuncName = context.addNewFunction(doAggVal,
320320
s"""
321321
| private void $doAggVal($argList) throws java.io.IOException {
322-
| ${evalAndUpdateCodes(i)}
322+
| ${codes(i)}
323323
| }
324324
""".stripMargin)
325325

326326
val inputVariables = args.map(_._2).mkString(", ")
327327
s"$doAggValFuncName($inputVariables);"
328328
} else {
329-
evalAndUpdateCodes(i)
329+
codes(i)
330330
}
331331
}
332332
}
@@ -377,7 +377,10 @@ case class HashAggregateExec(
377377
}
378378

379379
val updateAggValCode = splitAggregateExpressions(
380-
ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states)
380+
context = ctx,
381+
aggregateExpressions = boundUpdateExpr,
382+
codes = evalAndUpdateCodes,
383+
subExprs = subExprs.states)
381384

382385
s"""
383386
| // do aggregate
@@ -946,11 +949,11 @@ case class HashAggregateExec(
946949
}
947950

948951
val updateAggValCode = splitAggregateExpressions(
949-
ctx,
950-
boundUpdateExpr,
951-
evalAndUpdateCodes,
952-
subExprs.states,
953-
Seq(("InternalRow", unsafeRowBuffer)))
952+
context = ctx,
953+
aggregateExpressions = boundUpdateExpr,
954+
codes = evalAndUpdateCodes,
955+
subExprs = subExprs.states,
956+
otherArgs = Seq(("InternalRow", unsafeRowBuffer)))
954957

955958
s"""
956959
| // do aggregate
@@ -991,11 +994,11 @@ case class HashAggregateExec(
991994
}
992995

993996
val updateAggValCode = splitAggregateExpressions(
994-
ctx,
995-
boundUpdateExpr,
996-
evalAndUpdateCodes,
997-
subExprs.states,
998-
Seq(("InternalRow", fastRowBuffer)))
997+
context = ctx,
998+
aggregateExpressions = boundUpdateExpr,
999+
codes = evalAndUpdateCodes,
1000+
subExprs = subExprs.states,
1001+
otherArgs = Seq(("InternalRow", fastRowBuffer)))
9991002

10001003
// If fast hash map is on, we first generate code to update row in fast hash map, if the
10011004
// previous loop up hit fast hash map. Otherwise, update row in regular hash map.

0 commit comments

Comments
 (0)