@@ -203,7 +203,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   private def planAggregateWithoutDistinct(
       groupingExpressions: Seq[Expression],
       aggregateExpressions: Seq[AggregateExpression2],
-      aggregateFunctionMap: Map[AggregateFunction2, Attribute],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
     // 1. Create an Aggregate Operator for partial aggregations.
@@ -241,12 +241,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
     val finalAggregateAttributes =
       finalAggregateExpressions.map {
-        expr => aggregateFunctionMap(expr.aggregateFunction)
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
       }
     val rewrittenResultExpressions = resultExpressions.map { expr =>
       expr.transform {
         case agg: AggregateExpression2 =>
-          aggregateFunctionMap(agg.aggregateFunction).toAttribute
+          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
         case expression if groupExpressionMap.contains(expression) =>
           groupExpressionMap(expression).toAttribute
       }.asInstanceOf[NamedExpression]
@@ -266,7 +266,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       groupingExpressions: Seq[Expression],
       functionsWithDistinct: Seq[AggregateExpression2],
       functionsWithoutDistinct: Seq[AggregateExpression2],
-      aggregateFunctionMap: Map[AggregateFunction2, Attribute],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {

@@ -306,7 +306,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
       agg.aggregateFunction.bufferAttributes
     }
-    println(" namedDistinctColumnExpressions " + namedDistinctColumnExpressions)
     val partialAggregate =
       Aggregate2Sort(
         None: Option[Seq[Expression]],
@@ -323,7 +322,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
     val partialMergeAggregateAttributes =
       partialMergeAggregateExpressions.map {
-        expr => aggregateFunctionMap(expr.aggregateFunction)
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
       }
     val partialMergeAggregate =
       Aggregate2Sort(
@@ -336,34 +335,41 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

     // 3. Create an Aggregate Operator for the final and complete aggregation.
     val finalAggregateExpressions = functionsWithoutDistinct.map {
-      Need to replace the children to distinctColumnAttributes
       case AggregateExpression2(aggregateFunction, mode, _) =>
         AggregateExpression2(aggregateFunction, Final, false)
     }
     val finalAggregateAttributes =
       finalAggregateExpressions.map {
-        expr => aggregateFunctionMap(expr.aggregateFunction)
-      }
-    val completeAggregateExpressions = functionsWithDistinct.map {
-      case AggregateExpression2(aggregateFunction, mode, _) =>
-        AggregateExpression2(aggregateFunction, Complete, false)
-    }
-    val completeAggregateAttributes =
-      completeAggregateExpressions.map {
-        expr => aggregateFunctionMap(expr.aggregateFunction)
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
       }
+    val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
+      // Children of an AggregateFunction with the DISTINCT keyword have already
+      // been evaluated. Here, we need to replace the original children
+      // with AttributeReferences.
+      case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+        val rewrittenAggregateFunction = aggregateFunction.transformDown {
+          case expr if distinctColumnExpressionMap.contains(expr) =>
+            distinctColumnExpressionMap(expr).toAttribute
+        }.asInstanceOf[AggregateFunction2]
+        // We rewrite the aggregate function to a non-distinct aggregation because
+        // its input will have distinct arguments.
+        val rewrittenAggregateExpression =
+          AggregateExpression2(rewrittenAggregateFunction, Complete, false)
+
+        val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct)
+        (rewrittenAggregateExpression -> aggregateFunctionAttribute)
+    }.unzip

     val rewrittenResultExpressions = resultExpressions.map { expr =>
       expr.transform {
         case agg: AggregateExpression2 =>
-          aggregateFunctionMap(agg.aggregateFunction).toAttribute
+          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
         case expression if groupExpressionMap.contains(expression) =>
           groupExpressionMap(expression).toAttribute
-        case expression if distinctColumnExpressionMap.contains(expression) =>
-          distinctColumnExpressionMap(expression).toAttribute
       }.asInstanceOf[NamedExpression]
     }
     val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
+      namedGroupingAttributes ++ distinctColumnAttributes,
       namedGroupingAttributes,
       finalAggregateExpressions,
       finalAggregateAttributes,
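For context on the added block above: the rewriting step can be sketched in isolation. The following is a minimal, self-contained Scala sketch using hypothetical stand-in classes (`Column`, `AttrRef`, `Sum`, `distinctColumnMap`), not Catalyst's `Expression`/`AggregateFunction2` types or the real `distinctColumnExpressionMap`. It only illustrates the idea: any sub-expression that the child operator has already evaluated as a distinct column is swapped for a reference to that column's output attribute before the Complete-mode aggregation is planned.

```scala
object DistinctRewriteSketch {
  // Simplified expression tree; these are illustrative stand-ins, not Catalyst classes.
  sealed trait Expr
  case class Column(name: String) extends Expr
  case class AttrRef(name: String) extends Expr
  case class Sum(child: Expr) extends Expr

  // Hypothetical analogue of distinctColumnExpressionMap: the expression that
  // produced each distinct column, mapped to the attribute it is now available as.
  val distinctColumnMap: Map[Expr, AttrRef] = Map(Column("x") -> AttrRef("x#1"))

  // Replace any sub-expression the child operator already evaluated with its attribute.
  def rewrite(e: Expr): Expr = distinctColumnMap.get(e) match {
    case Some(attr) => attr
    case None =>
      e match {
        case Sum(child) => Sum(rewrite(child))
        case other      => other
      }
  }

  def main(args: Array[String]): Unit = {
    // SUM(DISTINCT x): the argument x is replaced by the attribute x#1 produced
    // by the distinct-evaluating child operator.
    println(rewrite(Sum(Column("x"))))  // Sum(AttrRef(x#1))
  }
}
```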
@@ -378,7 +384,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Aggregate(groupingExpressions, resultExpressions, child)
         if sqlContext.conf.useSqlAggregate2 =>
-        // 1. Extracts all distinct aggregate expressions from the resultExpressions.
+        // Extracts all distinct aggregate expressions from the resultExpressions.
         val aggregateExpressions = resultExpressions.flatMap { expr =>
           expr.collect {
             case agg: AggregateExpression2 => agg
@@ -388,12 +394,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         // to the corresponding attribute of the function.
         val aggregateFunctionMap = aggregateExpressions.map { agg =>
           val aggregateFunction = agg.aggregateFunction
-          aggregateFunction -> Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+          (aggregateFunction, agg.isDistinct) ->
+            Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
         }.toMap

         val (functionsWithDistinct, functionsWithoutDistinct) =
           aggregateExpressions.partition(_.isDistinct)
-        println(" functionsWithDistinct " + functionsWithDistinct)
         if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
           // This is a sanity check. We should not reach here since we check the same thing in
           // CheckAggregateFunction.
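The `(AggregateFunction2, Boolean)` key introduced above is what allows a query to compute both the distinct and the non-distinct variant of the same aggregate over the same column. A minimal, self-contained Scala sketch with hypothetical stand-in classes (`AggFunc`, `AggExpr`, not the Catalyst types in this diff) shows the collision that a function-only key would cause:

```scala
object AggregateMapSketch {
  // Simplified stand-ins for AggregateFunction2 / AggregateExpression2.
  case class AggFunc(name: String, child: String)
  case class AggExpr(func: AggFunc, isDistinct: Boolean)

  def main(args: Array[String]): Unit = {
    // COUNT(x) and COUNT(DISTINCT x) have structurally equal aggregate functions.
    val exprs = Seq(
      AggExpr(AggFunc("count", "x"), isDistinct = false),
      AggExpr(AggFunc("count", "x"), isDistinct = true))

    // Keyed by the function only: both expressions collide on one key, so one of
    // them would silently reuse the other's output attribute.
    val byFunc = exprs.map(e => e.func -> s"attr(${e.func.name}(${e.func.child}))").toMap
    println(byFunc.size)  // 1

    // Keyed by (function, isDistinct): each aggregate keeps its own output attribute.
    val byFuncAndDistinct = exprs.map { e =>
      val distinctTag = if (e.isDistinct) "DISTINCT " else ""
      (e.func, e.isDistinct) -> s"attr(${e.func.name}($distinctTag${e.func.child}))"
    }.toMap
    println(byFuncAndDistinct.size)  // 2
  }
}
```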