@@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
2020import org .apache .spark .sql .AnalysisException
2121import org .apache .spark .sql .catalyst ._
2222import org .apache .spark .sql .catalyst .expressions ._
23- import org .apache .spark .sql .catalyst .plans .logical .{Aggregate , LogicalPlan }
24- import org .apache .spark .sql .types .{StructType , MapType , ArrayType }
23+ import org .apache .spark .sql .catalyst .plans .logical .{Expand , Aggregate , LogicalPlan }
24+ import org .apache .spark .sql .catalyst .rules .Rule
25+ import org .apache .spark .sql .types .{IntegerType , StructType , MapType , ArrayType }
2526
2627/**
2728 * Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -41,7 +42,7 @@ object Utils {
4142
4243 private def doConvert (plan : LogicalPlan ): Option [Aggregate ] = plan match {
4344 case p : Aggregate if supportsGroupingKeySchema(p) =>
44- val converted = p.transformExpressionsDown {
45+ val converted = MultipleDistinctRewriter .rewrite( p.transformExpressionsDown {
4546 case expressions.Average (child) =>
4647 aggregate.AggregateExpression2 (
4748 aggregateFunction = aggregate.Average (child),
@@ -144,7 +145,8 @@ object Utils {
144145 aggregateFunction = aggregate.VarianceSamp (child),
145146 mode = aggregate.Complete ,
146147 isDistinct = false )
147- }
148+ })
149+
148150 // Check if there is any expressions.AggregateExpression1 left.
149151 // If so, we cannot convert this plan.
150152 val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
@@ -156,6 +158,7 @@ object Utils {
156158 }
157159
158160 // Check if there are multiple distinct columns.
161+ // TODO remove this.
159162 val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
160163 expr.collect {
161164 case agg : AggregateExpression2 => agg
@@ -213,3 +216,178 @@ object Utils {
213216 case other => None
214217 }
215218}
219+
220+ /**
221+ * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double
222+ * aggregation in which the regular aggregation expressions and every distinct clause is aggregated
223+ * in a separate group. The results are then combined in a second aggregate.
224+ *
225+ * TODO Expression cannocalization
226+ * TODO Eliminate foldable expressions from distinct clauses.
227+ * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
228+ * operator. Perhaps this is a good thing? It is much simpler to plan later on...
229+ */
230+ object MultipleDistinctRewriter extends Rule [LogicalPlan ] {
231+
232+ def apply (plan : LogicalPlan ): LogicalPlan = plan resolveOperators {
233+ case a : Aggregate => rewrite(a)
234+ case p => p
235+ }
236+
237+ def rewrite (a : Aggregate ): Aggregate = {
238+
239+ // Collect all aggregate expressions.
240+ val aggExpressions = a.aggregateExpressions.flatMap { e =>
241+ e.collect {
242+ case ae : AggregateExpression2 => ae
243+ }
244+ }
245+
246+ // Extract distinct aggregate expressions.
247+ val distinctAggGroups = aggExpressions
248+ .filter(_.isDistinct)
249+ .groupBy(_.aggregateFunction.children.toSet)
250+
251+ // Only continue to rewrite if there is more than one distinct group.
252+ if (distinctAggGroups.size > 1 ) {
253+ // Create the attributes for the grouping id and the group by clause.
254+ val gid = new AttributeReference (" gid" , IntegerType , false )()
255+ val groupByMap = a.groupingExpressions.collect {
256+ case ne : NamedExpression => ne -> ne.toAttribute
257+ case e => e -> new AttributeReference (e.prettyName, e.dataType, e.nullable)()
258+ }
259+ val groupByAttrs = groupByMap.map(_._2)
260+
261+ // Functions used to modify aggregate functions and their inputs.
262+ def evalWithinGroup (id : Literal , e : Expression ) = If (EqualTo (gid, id), e, nullify(e))
263+ def patchAggregateFunctionChildren (
264+ af : AggregateFunction2 ,
265+ id : Literal ,
266+ attrs : Map [Expression , Expression ]): AggregateFunction2 = {
267+ af.withNewChildren(af.children.map { case afc =>
268+ evalWithinGroup(id, attrs(afc))
269+ }).asInstanceOf [AggregateFunction2 ]
270+ }
271+
272+ // Setup unique distinct aggregate children.
273+ val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
274+ val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
275+ val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
276+
277+ // Setup expand & aggregate operators for distinct aggregate expressions.
278+ val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
279+ case ((group, expressions), i) =>
280+ val id = Literal (i + 1 )
281+
282+ // Expand projection
283+ val projection = distinctAggChildren.map {
284+ case e if group.contains(e) => e
285+ case e => nullify(e)
286+ } :+ id
287+
288+ // Final aggregate
289+ val operators = expressions.map { e =>
290+ val af = e.aggregateFunction
291+ val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
292+ (e, e.copy(aggregateFunction = naf, isDistinct = false ))
293+ }
294+
295+ (projection, operators)
296+ }
297+
298+ // Setup expand for the 'regular' aggregate expressions.
299+ val regularAggExprs = aggExpressions.filter(! _.isDistinct)
300+ val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
301+ val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap
302+
303+ // Setup aggregates for 'regular' aggregate expressions.
304+ val regularGroupId = Literal (0 )
305+ val regularAggOperatorMap = regularAggExprs.map { e =>
306+ // Perform the actual aggregation in the initial aggregate.
307+ val af = patchAggregateFunctionChildren(
308+ e.aggregateFunction,
309+ regularGroupId,
310+ regularAggChildAttrMap)
311+ val a = Alias (e.copy(aggregateFunction = af), e.toString)()
312+
313+ // Get the result of the first aggregate in the last aggregate.
314+ val b = AggregateExpression2 (
315+ aggregate.First (evalWithinGroup(regularGroupId, a.toAttribute), Literal (true )),
316+ mode = Complete ,
317+ isDistinct = false )
318+
319+ // Some aggregate functions (COUNT) have the special property that they can return a
320+ // non-null result without any input. We need to make sure we return a result in this case.
321+ val c = af.defaultResult match {
322+ case Some (lit) => Coalesce (Seq (b, lit))
323+ case None => b
324+ }
325+
326+ (e, a, c)
327+ }
328+
329+ // Construct the regular aggregate input projection only if we need one.
330+ val regularAggProjection = if (regularAggExprs.nonEmpty) {
331+ Seq (a.groupingExpressions ++
332+ distinctAggChildren.map(nullify) ++
333+ Seq (regularGroupId) ++
334+ regularAggChildren)
335+ } else {
336+ Seq .empty[Seq [Expression ]]
337+ }
338+
339+ // Construct the distinct aggregate input projections.
340+ val regularAggNulls = regularAggChildren.map(nullify)
341+ val distinctAggProjections = distinctAggOperatorMap.map {
342+ case (projection, _) =>
343+ a.groupingExpressions ++
344+ projection ++
345+ regularAggNulls
346+ }
347+
348+ // Construct the expand operator.
349+ val expand = Expand (
350+ regularAggProjection ++ distinctAggProjections,
351+ groupByAttrs ++ distinctAggChildAttrs ++ Seq (gid) ++ regularAggChildAttrMap.values.toSeq,
352+ a.child)
353+
354+ // Construct the first aggregate operator. This de-duplicates the all the children of
355+ // distinct operators, and applies the regular aggregate operators.
356+ val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
357+ val firstAggregate = Aggregate (
358+ firstAggregateGroupBy,
359+ firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
360+ expand)
361+
362+ // Construct the second aggregate
363+ val transformations : Map [Expression , Expression ] =
364+ (distinctAggOperatorMap.flatMap(_._2) ++
365+ regularAggOperatorMap.map(e => (e._1, e._3))).toMap
366+
367+ val patchedAggExpressions = a.aggregateExpressions.map { e =>
368+ e.transformDown {
369+ case e : Expression =>
370+ // The same GROUP BY clauses can have different forms (different names for instance) in
371+ // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
372+ // tricky. So we do a linear search for a semantically equal group by expression.
373+ groupByMap
374+ .find(ge => e.semanticEquals(ge._1))
375+ .map(_._2)
376+ .getOrElse(transformations.getOrElse(e, e))
377+ }.asInstanceOf [NamedExpression ]
378+ }
379+ Aggregate (groupByAttrs, patchedAggExpressions, firstAggregate)
380+ } else {
381+ a
382+ }
383+ }
384+
385+ private def nullify (e : Expression ) = Literal .create(null , e.dataType)
386+
387+ private def expressionAttributePair (e : Expression ) =
388+ // We are creating a new reference here instead of reusing the attribute in case of a
389+ // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
390+ // children, in this case attribute reuse causes the input of the regular aggregate to bound to
391+ // the (nulled out) input of the distinct aggregate.
392+ e -> new AttributeReference (e.prettyName, e.dataType, true )()
393+ }
0 commit comments