@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans._
2525import org .apache .spark .sql .catalyst .plans .logical .{BroadcastHint , LogicalPlan }
2626import org .apache .spark .sql .catalyst .plans .physical ._
2727import org .apache .spark .sql .columnar .{InMemoryColumnarTableScan , InMemoryRelation }
28- import org .apache .spark .sql .execution .aggregate2 .Aggregate2Sort
28+ import org .apache .spark .sql .execution .aggregate2 .{ FinalAndCompleteAggregate2Sort , Aggregate2Sort }
2929import org .apache .spark .sql .execution .{DescribeCommand => RunnableDescribeCommand }
3030import org .apache .spark .sql .parquet ._
3131import org .apache .spark .sql .sources .{CreateTableUsing , CreateTempTableUsing , DescribeCommand => LogicalDescribeCommand , _ }
@@ -200,6 +200,181 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
200200 * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
201201 */
202202 object AggregateOperator2 extends Strategy {
203+ private def planAggregateWithoutDistinct (
204+ groupingExpressions : Seq [Expression ],
205+ aggregateExpressions : Seq [AggregateExpression2 ],
206+ aggregateFunctionMap : Map [AggregateFunction2 , Attribute ],
207+ resultExpressions : Seq [NamedExpression ],
208+ child : SparkPlan ): Seq [SparkPlan ] = {
209+ // 1. Create an Aggregate Operator for partial aggregations.
210+ val namedGroupingExpressions = groupingExpressions.map {
211+ case ne : NamedExpression => ne -> ne
212+ // If the expression is not a NamedExpressions, we add an alias.
213+ // So, when we generate the result of the operator, the Aggregate Operator
214+ // can directly get the Seq of attributes representing the grouping expressions.
215+ case other =>
216+ val withAlias = Alias (other, other.toString)()
217+ other -> withAlias
218+ }
219+ val groupExpressionMap = namedGroupingExpressions.toMap
220+ val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
221+ val partialAggregateExpressions = aggregateExpressions.map {
222+ case AggregateExpression2 (aggregateFunction, mode, isDistinct) =>
223+ AggregateExpression2 (aggregateFunction, Partial , isDistinct)
224+ }
225+ val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
226+ agg.aggregateFunction.bufferAttributes
227+ }
228+ val partialAggregate =
229+ Aggregate2Sort (
230+ None : Option [Seq [Expression ]],
231+ namedGroupingExpressions.map(_._2),
232+ partialAggregateExpressions,
233+ partialAggregateAttributes,
234+ namedGroupingAttributes ++ partialAggregateAttributes,
235+ child)
236+
237+ // 2. Create an Aggregate Operator for final aggregations.
238+ val finalAggregateExpressions = aggregateExpressions.map {
239+ case AggregateExpression2 (aggregateFunction, mode, isDistinct) =>
240+ AggregateExpression2 (aggregateFunction, Final , isDistinct)
241+ }
242+ val finalAggregateAttributes =
243+ finalAggregateExpressions.map {
244+ expr => aggregateFunctionMap(expr.aggregateFunction)
245+ }
246+ val rewrittenResultExpressions = resultExpressions.map { expr =>
247+ expr.transform {
248+ case agg : AggregateExpression2 =>
249+ aggregateFunctionMap(agg.aggregateFunction).toAttribute
250+ case expression if groupExpressionMap.contains(expression) =>
251+ groupExpressionMap(expression).toAttribute
252+ }.asInstanceOf [NamedExpression ]
253+ }
254+ val finalAggregate = Aggregate2Sort (
255+ Some (namedGroupingAttributes),
256+ namedGroupingAttributes,
257+ finalAggregateExpressions,
258+ finalAggregateAttributes,
259+ rewrittenResultExpressions,
260+ partialAggregate)
261+
262+ finalAggregate :: Nil
263+ }
264+
265+ private def planAggregateWithOneDistinct (
266+ groupingExpressions : Seq [Expression ],
267+ functionsWithDistinct : Seq [AggregateExpression2 ],
268+ functionsWithoutDistinct : Seq [AggregateExpression2 ],
269+ aggregateFunctionMap : Map [AggregateFunction2 , Attribute ],
270+ resultExpressions : Seq [NamedExpression ],
271+ child : SparkPlan ): Seq [SparkPlan ] = {
272+
273+ // 1. Create an Aggregate Operator for partial aggregations.
274+ // The grouping expressions are original groupingExpressions and
275+ // distinct columns. For example, for avg(distinct value) ... group by key
276+ // the grouping expressions of this Aggregate Operator will be [key, value].
277+ val namedGroupingExpressions = groupingExpressions.map {
278+ case ne : NamedExpression => ne -> ne
279+ // If the expression is not a NamedExpressions, we add an alias.
280+ // So, when we generate the result of the operator, the Aggregate Operator
281+ // can directly get the Seq of attributes representing the grouping expressions.
282+ case other =>
283+ val withAlias = Alias (other, other.toString)()
284+ other -> withAlias
285+ }
286+ val groupExpressionMap = namedGroupingExpressions.toMap
287+ val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
288+
289+ // It is safe to call head at here since functionsWithDistinct has at least one
290+ // AggregateExpression2.
291+ val distinctColumnExpressions =
292+ functionsWithDistinct.head.aggregateFunction.children
293+ val namedDistinctColumnExpressions = distinctColumnExpressions.map {
294+ case ne : NamedExpression => ne -> ne
295+ case other =>
296+ val withAlias = Alias (other, other.toString)()
297+ other -> withAlias
298+ }
299+ val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
300+ val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
301+
302+ val partialAggregateExpressions = functionsWithoutDistinct.map {
303+ case AggregateExpression2 (aggregateFunction, mode, _) =>
304+ AggregateExpression2 (aggregateFunction, Partial , false )
305+ }
306+ val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
307+ agg.aggregateFunction.bufferAttributes
308+ }
309+ println(" namedDistinctColumnExpressions " + namedDistinctColumnExpressions)
310+ val partialAggregate =
311+ Aggregate2Sort (
312+ None : Option [Seq [Expression ]],
313+ (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2),
314+ partialAggregateExpressions,
315+ partialAggregateAttributes,
316+ namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes,
317+ child)
318+
319+ // 2. Create an Aggregate Operator for partial merge aggregations.
320+ val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
321+ case AggregateExpression2 (aggregateFunction, mode, _) =>
322+ AggregateExpression2 (aggregateFunction, PartialMerge , false )
323+ }
324+ val partialMergeAggregateAttributes =
325+ partialMergeAggregateExpressions.map {
326+ expr => aggregateFunctionMap(expr.aggregateFunction)
327+ }
328+ val partialMergeAggregate =
329+ Aggregate2Sort (
330+ Some (namedGroupingAttributes),
331+ namedGroupingAttributes ++ distinctColumnAttributes,
332+ partialMergeAggregateExpressions,
333+ partialMergeAggregateAttributes,
334+ namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes,
335+ partialAggregate)
336+
337+ // 3. Create an Aggregate Operator for partial merge aggregations.
338+ val finalAggregateExpressions = functionsWithoutDistinct.map {
339+ Need to replace the children to distinctColumnAttributes
340+ case AggregateExpression2 (aggregateFunction, mode, _) =>
341+ AggregateExpression2 (aggregateFunction, Final , false )
342+ }
343+ val finalAggregateAttributes =
344+ finalAggregateExpressions.map {
345+ expr => aggregateFunctionMap(expr.aggregateFunction)
346+ }
347+ val completeAggregateExpressions = functionsWithDistinct.map {
348+ case AggregateExpression2 (aggregateFunction, mode, _) =>
349+ AggregateExpression2 (aggregateFunction, Complete , false )
350+ }
351+ val completeAggregateAttributes =
352+ completeAggregateExpressions.map {
353+ expr => aggregateFunctionMap(expr.aggregateFunction)
354+ }
355+
356+ val rewrittenResultExpressions = resultExpressions.map { expr =>
357+ expr.transform {
358+ case agg : AggregateExpression2 =>
359+ aggregateFunctionMap(agg.aggregateFunction).toAttribute
360+ case expression if groupExpressionMap.contains(expression) =>
361+ groupExpressionMap(expression).toAttribute
362+ case expression if distinctColumnExpressionMap.contains(expression) =>
363+ distinctColumnExpressionMap(expression).toAttribute
364+ }.asInstanceOf [NamedExpression ]
365+ }
366+ val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort (
367+ namedGroupingAttributes,
368+ finalAggregateExpressions,
369+ finalAggregateAttributes,
370+ completeAggregateExpressions,
371+ completeAggregateAttributes,
372+ rewrittenResultExpressions,
373+ partialMergeAggregate)
374+
375+ finalAndCompleteAggregate :: Nil
376+ }
377+
203378 def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
204379 case logical.Aggregate (groupingExpressions, resultExpressions, child)
205380 if sqlContext.conf.useSqlAggregate2 =>
@@ -216,58 +391,33 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
216391 aggregateFunction -> Alias (aggregateFunction, aggregateFunction.toString)().toAttribute
217392 }.toMap
218393
219- // 2. Create an Aggregate Operator for partial aggregations.
220- val namedGroupingExpressions = groupingExpressions.map {
221- case ne : NamedExpression => ne -> ne
222- // If the expression is not a NamedExpressions, we add an alias.
223- // So, when we generate the result of the operator, the Aggregate Operator
224- // can directly get the Seq of attributes representing the grouping expressions.
225- case other =>
226- val withAlias = Alias (other, other.toString)()
227- other -> withAlias
228- }
229- val groupExpressionMap = namedGroupingExpressions.toMap
230- val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
231- val partialAggregateExpressions = aggregateExpressions.map {
232- case AggregateExpression2 (aggregateFunction, mode, isDistinct) =>
233- AggregateExpression2 (aggregateFunction, Partial , isDistinct)
394+ val (functionsWithDistinct, functionsWithoutDistinct) =
395+ aggregateExpressions.partition(_.isDistinct)
396+ println(" functionsWithDistinct " + functionsWithDistinct)
397+ if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1 ) {
398+ // This is a sanity check. We should not reach here since we check the same thing in
399+ // CheckAggregateFunction.
400+ sys.error(" Having more than one distinct column sets is not allowed." )
234401 }
235- val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
236- agg.aggregateFunction.bufferAttributes
237- }
238- val partialAggregate =
239- Aggregate2Sort (
240- namedGroupingExpressions.map(_._2),
241- partialAggregateExpressions,
242- partialAggregateAttributes,
243- namedGroupingAttributes ++ partialAggregateAttributes,
244- planLater(child))
245-
246- // 3. Create an Aggregate Operator for final aggregations.
247- val finalAggregateExpressions = aggregateExpressions.map {
248- case AggregateExpression2 (aggregateFunction, mode, isDistinct) =>
249- AggregateExpression2 (aggregateFunction, Final , isDistinct)
250- }
251- val finalAggregateAttributes =
252- finalAggregateExpressions.map {
253- expr => aggregateFunctionMap(expr.aggregateFunction)
402+ val aggregate =
403+ if (functionsWithDistinct.isEmpty) {
404+ planAggregateWithoutDistinct(
405+ groupingExpressions,
406+ aggregateExpressions,
407+ aggregateFunctionMap,
408+ resultExpressions,
409+ planLater(child))
410+ } else {
411+ planAggregateWithOneDistinct(
412+ groupingExpressions,
413+ functionsWithDistinct,
414+ functionsWithoutDistinct,
415+ aggregateFunctionMap,
416+ resultExpressions,
417+ planLater(child))
254418 }
255- val rewrittenResultExpressions = resultExpressions.map { expr =>
256- expr.transform {
257- case agg : AggregateExpression2 =>
258- aggregateFunctionMap(agg.aggregateFunction).toAttribute
259- case expression if groupExpressionMap.contains(expression) =>
260- groupExpressionMap(expression).toAttribute
261- }.asInstanceOf [NamedExpression ]
262- }
263- val finalAggregate = Aggregate2Sort (
264- namedGroupingAttributes,
265- finalAggregateExpressions,
266- finalAggregateAttributes,
267- rewrittenResultExpressions,
268- partialAggregate)
269419
270- finalAggregate :: Nil
420+ aggregate
271421 case _ => Nil
272422 }
273423 }
0 commit comments