@@ -116,6 +116,14 @@ case class TungstenAggregate(
116116 // the distinct modes of all aggregate expressions
117117 private val modes = aggregateExpressions.map(_.mode).distinct
118118
/**
 * The attributes this operator reports as referenced: the child's whole output set.
 *
 * NOTE(review): this method previously also built a narrower set (the references of the
 * grouping expressions, plus `inputAggBufferAttributes` for Final/PartialMerge aggregates
 * and `references` for Partial/Complete aggregates), but that value was a discarded
 * statement and never returned. It has been removed as dead code, so the returned value
 * is unchanged. If the narrower set is what callers should see, return it explicitly as
 * a deliberate, separately tested change.
 */
override def references: AttributeSet = child.outputSet
126+
119127 override def supportCodegen : Boolean = {
120128 // ImperativeAggregate is not supported right now
121129 ! aggregateExpressions.exists(_.aggregateFunction.isInstanceOf [ImperativeAggregate ])
@@ -164,47 +172,47 @@ case class TungstenAggregate(
164172 """ .stripMargin
165173 ExprCode (ev.code + initVars, isNull, value)
166174 }
175+ val initBufVar = evaluateVariables(bufVars)
167176
168177 // generate variables for output
169- val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
170178 val (resultVars, genResult) = if (modes.contains(Final ) || modes.contains(Complete )) {
171179 // evaluate aggregate results
172180 ctx.currentVars = bufVars
173181 val aggResults = functions.map(_.evaluateExpression).map { e =>
174- BindReferences .bindReference(e, bufferAttrs ).gen(ctx)
182+ BindReferences .bindReference(e, aggregateBufferAttributes ).gen(ctx)
175183 }
176184 // evaluate result expressions
177185 ctx.currentVars = aggResults
178186 val resultVars = resultExpressions.map { e =>
179187 BindReferences .bindReference(e, aggregateAttributes).gen(ctx)
180188 }
181189 (resultVars, s """
182- | ${aggResults.map(_.code).mkString( " \n " )}
183- | ${resultVars.map(_.code).mkString( " \n " )}
190+ | ${evaluateVariables(aggResults )}
191+ | ${evaluateVariables(resultVars )}
184192 """ .stripMargin)
185193 } else if (modes.contains(Partial ) || modes.contains(PartialMerge )) {
186194 // output the aggregate buffer directly
187195 (bufVars, " " )
188196 } else {
189197 // no aggregate function, the result should be literals
190198 val resultVars = resultExpressions.map(_.gen(ctx))
191- (resultVars, resultVars.map(_.code).mkString( " \n " ))
199+ (resultVars, evaluateVariables(resultVars ))
192200 }
193201
194202 val doAgg = ctx.freshName(" doAggregateWithoutKey" )
195203 ctx.addNewFunction(doAgg,
196204 s """
197205 | private void $doAgg() throws java.io.IOException {
198206 | // initialize aggregation buffer
199- | ${bufVars.map(_.code).mkString( " \n " )}
207+ | $initBufVar
200208 |
201209 | ${child.asInstanceOf [CodegenSupport ].produce(ctx, this )}
202210 | }
203211 """ .stripMargin)
204212
205213 val numOutput = metricTerm(ctx, " numOutputRows" )
206214 s """
207- | if (! $initAgg) {
215+ | while (! $initAgg) {
208216 | $initAgg = true;
209217 | $doAgg();
210218 |
@@ -241,7 +249,7 @@ case class TungstenAggregate(
241249 }
242250 s """
243251 | // do aggregate
244- | ${aggVals.map(_.code).mkString( " \n " ).trim }
252+ | ${evaluateVariables(aggVals) }
245253 | // update aggregation buffer
246254 | ${updates.mkString(" \n " ).trim}
247255 """ .stripMargin
@@ -252,8 +260,7 @@ case class TungstenAggregate(
// The aggregate functions that are DeclarativeAggregate instances, in their original
// order; aggregate functions of any other kind are left out.
private val declFunctions = aggregateExpressions
  .map(_.aggregateFunction)
  // A type pattern filters and narrows in one pass, avoiding isInstanceOf/asInstanceOf.
  .collect { case declarative: DeclarativeAggregate => declarative }
255- private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
256- private val bufferSchema = StructType .fromAttributes(bufferAttributes)
263+ private val bufferSchema = StructType .fromAttributes(aggregateBufferAttributes)
257264
258265 // The name for HashMap
259266 private var hashMapTerm : String = _
@@ -318,7 +325,7 @@ case class TungstenAggregate(
318325 val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
319326 val mergeProjection = newMutableProjection(
320327 mergeExpr,
321- bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
328+ aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
322329 subexpressionEliminationEnabled)()
323330 val joinedRow = new JoinedRow ()
324331
@@ -381,13 +388,13 @@ case class TungstenAggregate(
381388 BoundReference (i, e.dataType, e.nullable).gen(ctx)
382389 }
383390 ctx.INPUT_ROW = bufferTerm
384- val bufferVars = bufferAttributes .zipWithIndex.map { case (e, i) =>
391+ val bufferVars = aggregateBufferAttributes .zipWithIndex.map { case (e, i) =>
385392 BoundReference (i, e.dataType, e.nullable).gen(ctx)
386393 }
387394 // evaluate the aggregation result
388395 ctx.currentVars = bufferVars
389396 val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
390- BindReferences .bindReference(e, bufferAttributes ).gen(ctx)
397+ BindReferences .bindReference(e, aggregateBufferAttributes ).gen(ctx)
391398 }
392399 // generate the final result
393400 ctx.currentVars = keyVars ++ aggResults
@@ -396,11 +403,9 @@ case class TungstenAggregate(
396403 BindReferences .bindReference(e, inputAttrs).gen(ctx)
397404 }
398405 s """
399- ${keyVars.map(_.code).mkString(" \n " )}
400- ${bufferVars.map(_.code).mkString(" \n " )}
401- ${aggResults.map(_.code).mkString(" \n " )}
402- ${resultVars.map(_.code).mkString(" \n " )}
403-
406+ ${evaluateVariables(keyVars)}
407+ ${evaluateVariables(bufferVars)}
408+ ${evaluateVariables(aggResults)}
404409 ${consume(ctx, resultVars)}
405410 """
406411
@@ -422,10 +427,7 @@ case class TungstenAggregate(
422427 val eval = resultExpressions.map{ e =>
423428 BindReferences .bindReference(e, groupingAttributes).gen(ctx)
424429 }
425- s """
426- ${eval.map(_.code).mkString(" \n " )}
427- ${consume(ctx, eval)}
428- """
430+ consume(ctx, eval)
429431 }
430432 }
431433
@@ -508,8 +510,8 @@ case class TungstenAggregate(
508510 ctx.currentVars = input
509511 val hashEval = BindReferences .bindReference(hashExpr, child.output).gen(ctx)
510512
511- val inputAttr = bufferAttributes ++ child.output
512- ctx.currentVars = new Array [ExprCode ](bufferAttributes .length) ++ input
513+ val inputAttr = aggregateBufferAttributes ++ child.output
514+ ctx.currentVars = new Array [ExprCode ](aggregateBufferAttributes .length) ++ input
513515 ctx.INPUT_ROW = buffer
514516 // TODO: support subexpression elimination
515517 val evals = updateExpr.map(BindReferences .bindReference(_, inputAttr).gen(ctx))
@@ -557,7 +559,7 @@ case class TungstenAggregate(
557559 $incCounter
558560
559561 // evaluate aggregate function
560- ${evals.map(_.code).mkString( " \n " ).trim }
562+ ${evaluateVariables(evals) }
561563 // update aggregate buffer
562564 ${updates.mkString(" \n " ).trim}
563565 """
0 commit comments