@@ -222,10 +222,76 @@ object Utils {
222222 * aggregation in which the regular aggregation expressions and every distinct clause is aggregated
223223 * in a separate group. The results are then combined in a second aggregate.
224224 *
225- * TODO Expression cannocalization
226- * TODO Eliminate foldable expressions from distinct clauses.
227- * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
228- * operator. Perhaps this is a good thing? It is much simpler to plan later on...
225+ * For example (in scala):
226+ * {{{
227+ * val data = Seq(
228+ * ("a", "ca1", "cb1", 10),
229+ * ("a", "ca1", "cb2", 5),
230+ * ("b", "ca1", "cb1", 13))
231+ * .toDF("key", "cat1", "cat2", "value")
232+ * data.registerTempTable("data")
233+ *
234+ * val agg = data.groupBy($"key")
235+ * .agg(
236+ * countDistinct($"cat1").as("cat1_cnt"),
237+ * countDistinct($"cat2").as("cat2_cnt"),
238+ * sum($"value").as("total"))
239+ * }}}
240+ *
241+ * This translates to the following (pseudo) logical plan:
242+ * {{{
243+ * Aggregate(
244+ * key = ['key]
245+ * functions = [COUNT(DISTINCT 'cat1),
246+ * COUNT(DISTINCT 'cat2),
247+ * sum('value)]
248+ * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
249+ * LocalTableScan [...]
250+ * }}}
251+ *
252+ * This rule rewrites this logical plan to the following (pseudo) logical plan:
253+ * {{{
254+ * Aggregate(
255+ * key = ['key]
256+ * functions = [count(if (('gid = 1)) 'cat1 else null),
257+ * count(if (('gid = 2)) 'cat2 else null),
258+ * first(if (('gid = 0)) 'total else null) ignore nulls]
259+ * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
260+ * Aggregate(
261+ * key = ['key, 'cat1, 'cat2, 'gid]
262+ * functions = [sum('value)]
263+ * output = ['key, 'cat1, 'cat2, 'gid, 'total])
264+ * Expand(
265+ * projections = [('key, null, null, 0, cast('value as bigint)),
266+ * ('key, 'cat1, null, 1, null),
267+ * ('key, null, 'cat2, 2, null)]
268+ * output = ['key, 'cat1, 'cat2, 'gid, 'value])
269+ * LocalTableScan [...]
270+ * }}}
271+ *
272+ * The rule does the following things here:
273+ * 1. Expand the data. There are three aggregation groups in this query:
274+ * i. the non-distinct group;
275+ * ii. the distinct 'cat1 group;
276+ * iii. the distinct 'cat2 group.
277+ * An expand operator is inserted to expand the child data for each group. The expand will null
278+ * out all unused columns for the given group; this must be done in order to ensure correctness
279+ * later on. Groups can by identified by a group id (gid) column added by the expand operator.
280+ * 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of
281+ * this aggregate consists of the original group by clause, all the requested distinct columns
282+ * and the group id. Both de-duplication of distinct column and the aggregation of the
283+ * non-distinct group take advantage of the fact that we group by the group id (gid) and that we
284+ * have nulled out all non-relevant columns for the the given group.
285+ * 3. Aggregating the distinct groups and combining this with the results of the non-distinct
286+ * aggregation. In this step we use the group id to filter the inputs for the aggregate
287+ * functions. The result of the non-distinct group are 'aggregated' by using the first operator,
288+ * it might be more elegant to use the native UDAF merge mechanism for this in the future.
289+ *
290+ * This rule duplicates the input data by two or more times (# distinct groups + an optional
291+ * non-distinct group). This will put quite a bit of memory pressure of the used aggregate and
292+ * exchange operators. Keeping the number of distinct groups as low a possible should be priority,
293+ * we could improve this in the current rule by applying more advanced expression cannocalization
294+ * techniques.
229295 */
230296object MultipleDistinctRewriter extends Rule [LogicalPlan ] {
231297
@@ -261,11 +327,10 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
261327 // Functions used to modify aggregate functions and their inputs.
262328 def evalWithinGroup (id : Literal , e : Expression ) = If (EqualTo (gid, id), e, nullify(e))
263329 def patchAggregateFunctionChildren (
264- af : AggregateFunction2 ,
265- id : Literal ,
266- attrs : Map [Expression , Expression ]): AggregateFunction2 = {
267- af.withNewChildren(af.children.map { case afc =>
268- evalWithinGroup(id, attrs(afc))
330+ af : AggregateFunction2 )(
331+ attrs : Expression => Expression ): AggregateFunction2 = {
332+ af.withNewChildren(af.children.map {
333+ case afc => attrs(afc)
269334 }).asInstanceOf [AggregateFunction2 ]
270335 }
271336
@@ -288,7 +353,9 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
288353 // Final aggregate
289354 val operators = expressions.map { e =>
290355 val af = e.aggregateFunction
291- val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
356+ val naf = patchAggregateFunctionChildren(af) { x =>
357+ evalWithinGroup(id, distinctAggChildAttrMap(x))
358+ }
292359 (e, e.copy(aggregateFunction = naf, isDistinct = false ))
293360 }
294361
@@ -304,26 +371,27 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
304371 val regularGroupId = Literal (0 )
305372 val regularAggOperatorMap = regularAggExprs.map { e =>
306373 // Perform the actual aggregation in the initial aggregate.
307- val af = patchAggregateFunctionChildren(
308- e.aggregateFunction,
309- regularGroupId,
310- regularAggChildAttrMap)
311- val a = Alias (e.copy(aggregateFunction = af), e.toString)()
312-
313- // Get the result of the first aggregate in the last aggregate.
314- val b = AggregateExpression2 (
315- aggregate.First (evalWithinGroup(regularGroupId, a.toAttribute), Literal (true )),
374+ val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap)
375+ val operator = Alias (e.copy(aggregateFunction = af), e.toString)()
376+
377+ // Select the result of the first aggregate in the last aggregate.
378+ val result = AggregateExpression2 (
379+ aggregate.First (evalWithinGroup(regularGroupId, operator.toAttribute), Literal (true )),
316380 mode = Complete ,
317381 isDistinct = false )
318382
319383 // Some aggregate functions (COUNT) have the special property that they can return a
320384 // non-null result without any input. We need to make sure we return a result in this case.
321- val c = af.defaultResult match {
322- case Some (lit) => Coalesce (Seq (b , lit))
323- case None => b
385+ val resultWithDefault = af.defaultResult match {
386+ case Some (lit) => Coalesce (Seq (result , lit))
387+ case None => result
324388 }
325389
326- (e, a, c)
390+ // Return a Tuple3 containing:
391+ // i. The original aggregate expression (used for look ups).
392+ // ii. The actual aggregation operator (used in the first aggregate).
393+ // iii. The operator that selects and returns the result (used in the second aggregate).
394+ (e, operator, resultWithDefault)
327395 }
328396
329397 // Construct the regular aggregate input projection only if we need one.
0 commit comments