
Commit ef36284

hvanhovell authored and yhuai committed
[SPARK-9241][SQL] Supporting multiple DISTINCT columns - follow-up
This PR is a follow-up to PR #9406. It adds more documentation to the rewriting rule, removes a redundant if expression in the non-distinct aggregation path, and adds a multiple-distinct test to the AggregationQuerySuite.

cc yhuai marmbrus

Author: Herman van Hovell <[email protected]>

Closes #9541 from hvanhovell/SPARK-9241-followup.
1 parent 2ff0e79 commit ef36284

File tree

2 files changed, +108 -23 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala

Lines changed: 91 additions & 23 deletions
@@ -222,10 +222,76 @@ object Utils {
  * aggregation in which the regular aggregation expressions and every distinct clause is aggregated
  * in a separate group. The results are then combined in a second aggregate.
  *
- * TODO Expression cannocalization
- * TODO Eliminate foldable expressions from distinct clauses.
- * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
- *      operator. Perhaps this is a good thing? It is much simpler to plan later on...
+ * For example (in Scala):
+ * {{{
+ *   val data = Seq(
+ *     ("a", "ca1", "cb1", 10),
+ *     ("a", "ca1", "cb2", 5),
+ *     ("b", "ca1", "cb1", 13))
+ *     .toDF("key", "cat1", "cat2", "value")
+ *   data.registerTempTable("data")
+ *
+ *   val agg = data.groupBy($"key")
+ *     .agg(
+ *       countDistinct($"cat1").as("cat1_cnt"),
+ *       countDistinct($"cat2").as("cat2_cnt"),
+ *       sum($"value").as("total"))
+ * }}}
+ *
+ * This translates to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ *    key = ['key]
+ *    functions = [COUNT(DISTINCT 'cat1),
+ *                 COUNT(DISTINCT 'cat2),
+ *                 sum('value)]
+ *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ *   LocalTableScan [...]
+ * }}}
+ *
+ * This rule rewrites this logical plan to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ *    key = ['key]
+ *    functions = [count(if (('gid = 1)) 'cat1 else null),
+ *                 count(if (('gid = 2)) 'cat2 else null),
+ *                 first(if (('gid = 0)) 'total else null) ignore nulls]
+ *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ *   Aggregate(
+ *      key = ['key, 'cat1, 'cat2, 'gid]
+ *      functions = [sum('value)]
+ *      output = ['key, 'cat1, 'cat2, 'gid, 'total])
+ *     Expand(
+ *        projections = [('key, null, null, 0, cast('value as bigint)),
+ *                       ('key, 'cat1, null, 1, null),
+ *                       ('key, null, 'cat2, 2, null)]
+ *        output = ['key, 'cat1, 'cat2, 'gid, 'value])
+ *       LocalTableScan [...]
+ * }}}
+ *
+ * The rule does the following things here:
+ * 1. Expand the data. There are three aggregation groups in this query:
+ *    i. the non-distinct group;
+ *    ii. the distinct 'cat1 group;
+ *    iii. the distinct 'cat2 group.
+ *    An expand operator is inserted to expand the child data for each group. The expand will null
+ *    out all unused columns for the given group; this must be done in order to ensure correctness
+ *    later on. Groups can be identified by a group id (gid) column added by the expand operator.
+ * 2. De-duplicate the distinct paths and aggregate the non-distinct path. The group by clause of
+ *    this aggregate consists of the original group by clause, all the requested distinct columns
+ *    and the group id. Both the de-duplication of the distinct columns and the aggregation of the
+ *    non-distinct group take advantage of the fact that we group by the group id (gid) and that
+ *    we have nulled out all non-relevant columns for the given group.
+ * 3. Aggregate the distinct groups and combine the result with the output of the non-distinct
+ *    aggregation. In this step we use the group id to filter the inputs for the aggregate
+ *    functions. The results of the non-distinct group are 'aggregated' using the first operator;
+ *    it might be more elegant to use the native UDAF merge mechanism for this in the future.
+ *
+ * This rule duplicates the input data two or more times (# distinct groups + an optional
+ * non-distinct group). This puts quite a bit of memory pressure on the aggregate and exchange
+ * operators. Keeping the number of distinct groups as low as possible should therefore be a
+ * priority; we could improve on this in the current rule by applying more advanced expression
+ * canonicalization techniques.
  */
 object MultipleDistinctRewriter extends Rule[LogicalPlan] {

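To make the documented rewrite concrete, here is a minimal, self-contained sketch that reproduces the three-group plan by hand with public DataFrame operations. It assumes a modern Spark session in local mode, simulates the Expand operator as a union of three projections, and illustrates the plan above rather than the rule's actual implementation:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object ManualDistinctRewriteSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("rewrite-sketch").getOrCreate()
  import spark.implicits._

  val data = Seq(
    ("a", "ca1", "cb1", 10),
    ("a", "ca1", "cb2", 5),
    ("b", "ca1", "cb1", 13)).toDF("key", "cat1", "cat2", "value")

  // "Expand": one projection per aggregation group; unused columns are
  // nulled out and a group id (gid) column identifies each group.
  val nullStr = lit(null).cast("string")
  val nullLong = lit(null).cast("bigint")
  val expanded = data
    .select($"key", nullStr.as("cat1"), nullStr.as("cat2"), lit(0).as("gid"),
      $"value".cast("bigint").as("value"))
    .union(data.select($"key", $"cat1", nullStr.as("cat2"), lit(1).as("gid"), nullLong.as("value")))
    .union(data.select($"key", nullStr.as("cat1"), $"cat2", lit(2).as("gid"), nullLong.as("value")))

  // First aggregate: grouping by (key, cat1, cat2, gid) de-duplicates the
  // distinct columns (gid 1 and 2) and pre-aggregates the non-distinct group (gid 0).
  val firstAgg = expanded
    .groupBy($"key", $"cat1", $"cat2", $"gid")
    .agg(sum($"value").as("total"))

  // Second aggregate: the gid predicate routes rows to the right function,
  // mirroring count(if (('gid = 1)) 'cat1 else null) in the plan above.
  firstAgg.groupBy($"key")
    .agg(
      count(when($"gid" === 1, $"cat1")).as("cat1_cnt"),
      count(when($"gid" === 2, $"cat2")).as("cat2_cnt"),
      first(when($"gid" === 0, $"total"), ignoreNulls = true).as("total"))
    .show()
  // Expected: ("a", 1, 2, 15) and ("b", 1, 1, 13), the same answer the
  // direct countDistinct query would give.
}
```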
@@ -261,11 +327,10 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
       // Functions used to modify aggregate functions and their inputs.
       def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
       def patchAggregateFunctionChildren(
-          af: AggregateFunction2,
-          id: Literal,
-          attrs: Map[Expression, Expression]): AggregateFunction2 = {
-        af.withNewChildren(af.children.map { case afc =>
-          evalWithinGroup(id, attrs(afc))
+          af: AggregateFunction2)(
+          attrs: Expression => Expression): AggregateFunction2 = {
+        af.withNewChildren(af.children.map {
+          case afc => attrs(afc)
         }).asInstanceOf[AggregateFunction2]
       }

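The refactor above replaces the fixed (function, group id, attribute map) parameter list with a curried form whose second parameter list takes an arbitrary child rewrite. A toy sketch of that shape, using illustrative stand-in classes rather than Catalyst's real expression nodes:

```scala
// Illustrative stand-ins for Catalyst expression nodes (not real Spark classes).
sealed trait Expr
final case class Attr(name: String) extends Expr
final case class GidFilter(gid: Int, child: Expr) extends Expr
final case class AggFunc(name: String, children: Seq[Expr]) extends Expr

object CurriedPatchSketch extends App {
  // Curried patcher: each call site decides how a child is rewritten.
  def patchChildren(af: AggFunc)(rewrite: Expr => Expr): AggFunc =
    af.copy(children = af.children.map(rewrite))

  val attrMap: Map[Expr, Expr] = Map(Attr("cat1") -> Attr("cat1#42"))

  // Distinct path: look the child up, then wrap it in a group-id filter.
  val distinct = patchChildren(AggFunc("count", Seq(Attr("cat1")))) { c =>
    GidFilter(1, attrMap(c))
  }

  // Non-distinct path: a Map is already a function from key to value,
  // so the attribute map can be passed as the rewrite directly.
  val regular = patchChildren(AggFunc("sum", Seq(Attr("cat1"))))(attrMap)

  println(distinct) // AggFunc(count,List(GidFilter(1,Attr(cat1#42))))
  println(regular)  // AggFunc(sum,List(Attr(cat1#42)))
}
```

This mirrors the commit's two call sites: the distinct path wraps each rewritten child in evalWithinGroup, while the regular path passes regularAggChildAttrMap directly, relying on the fact that a Scala Map is itself a Function1.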
@@ -288,7 +353,9 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
       // Final aggregate
       val operators = expressions.map { e =>
         val af = e.aggregateFunction
-        val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
+        val naf = patchAggregateFunctionChildren(af) { x =>
+          evalWithinGroup(id, distinctAggChildAttrMap(x))
+        }
         (e, e.copy(aggregateFunction = naf, isDistinct = false))
       }

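The distinct path builds its rewrite from evalWithinGroup, i.e. If(EqualTo(gid, id), e, nullify(e)). A small, self-contained sketch of why nulling out foreign groups suffices (the data and names here are illustrative): rows that belong to another group become null, and count never sees them:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object EvalWithinGroupSketch extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  import spark.implicits._

  // Only gid-1 rows carry cat1; the gid-0 row belongs to another group.
  val rows = Seq((0, Option.empty[String]), (1, Some("ca1")), (1, Some("ca2")))
    .toDF("gid", "cat1")

  // evalWithinGroup(1, 'cat1) behaves like: if (gid = 1) cat1 else null,
  // so count only tallies the rows of its own group.
  rows.agg(count(when($"gid" === 1, $"cat1")).as("cat1_cnt")).show()
  // Expected: cat1_cnt = 2.
}
```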
@@ -304,26 +371,27 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
     val regularGroupId = Literal(0)
     val regularAggOperatorMap = regularAggExprs.map { e =>
       // Perform the actual aggregation in the initial aggregate.
-      val af = patchAggregateFunctionChildren(
-        e.aggregateFunction,
-        regularGroupId,
-        regularAggChildAttrMap)
-      val a = Alias(e.copy(aggregateFunction = af), e.toString)()
-
-      // Get the result of the first aggregate in the last aggregate.
-      val b = AggregateExpression2(
-        aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)),
+      val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap)
+      val operator = Alias(e.copy(aggregateFunction = af), e.toString)()
+
+      // Select the result of the first aggregate in the last aggregate.
+      val result = AggregateExpression2(
+        aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
         mode = Complete,
         isDistinct = false)
 
       // Some aggregate functions (COUNT) have the special property that they can return a
       // non-null result without any input. We need to make sure we return a result in this case.
-      val c = af.defaultResult match {
-        case Some(lit) => Coalesce(Seq(b, lit))
-        case None => b
+      val resultWithDefault = af.defaultResult match {
+        case Some(lit) => Coalesce(Seq(result, lit))
+        case None => result
       }
 
-      (e, a, c)
+      // Return a Tuple3 containing:
+      //   i. The original aggregate expression (used for look ups).
+      //  ii. The actual aggregation operator (used in the first aggregate).
+      // iii. The operator that selects and returns the result (used in the second aggregate).
+      (e, operator, resultWithDefault)
     }
 
     // Construct the regular aggregate input projection only if we need one.

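The defaultResult handling is worth a concrete illustration. COUNT returns a non-null result (0) even on empty input, but the rewritten plan fetches the pre-aggregated value with first(...) ignore nulls, which yields null when the non-distinct group contributed no rows for a key. A minimal sketch of the failure mode and the Coalesce fix (modern Spark API, illustrative names):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object DefaultResultSketch extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  import spark.implicits._

  // Simulated input to the second aggregate: key 3 only appears in a
  // distinct group (gid = 1), so the non-distinct path contributed nothing.
  val stage2 = Seq((3, 1, Option.empty[Long])).toDF("key", "gid", "total")

  stage2.groupBy($"key")
    .agg(
      // Without the default: first(...) over all-null input is null.
      first(when($"gid" === 0, $"total"), ignoreNulls = true).as("no_default"),
      // With the default: Coalesce restores COUNT's "no input -> 0" behaviour.
      coalesce(first(when($"gid" === 0, $"total"), ignoreNulls = true), lit(0L))
        .as("with_default"))
    .show()
  // Expected: no_default = null, with_default = 0.
}
```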
sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 17 additions & 0 deletions
@@ -516,6 +516,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
         Row(3, 4, 4, 3, null) :: Nil)
   }
 
+  test("multiple distinct column sets") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  key,
+          |  count(distinct value1),
+          |  count(distinct value2)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(null, 3, 3) ::
+        Row(1, 2, 3) ::
+        Row(2, 2, 1) ::
+        Row(3, 0, 1) :: Nil)
+  }
+
   test("test count") {
     checkAnswer(
       sqlContext.sql(