
Commit 162a770

hvanhovell authored and marmbrus committed
[SPARK-9241][SQL] Supporting multiple DISTINCT columns (2) - Rewriting Rule
The second PR for SPARK-9241; this adds support for multiple distinct columns to the new aggregation code path.

This PR solves the multiple DISTINCT column problem by rewriting these Aggregates into an Expand-Aggregate-Aggregate combination. See the [JIRA ticket](https://issues.apache.org/jira/browse/SPARK-9241) for some information on this. The advantages over the - competing - [first PR](#9280) are:

- This can use the faster TungstenAggregate code path.
- It is impossible to OOM due to an `OpenHashSet` allocating too much memory.

However, this will multiply the number of input rows by the number of distinct clauses (plus one), and it puts a lot more memory pressure on the aggregation code path itself. The location of this Rule is a bit funny, and should probably change when the old aggregation path is changed.

cc yhuai - Could you also tell me where to add tests for this?

Author: Herman van Hovell <[email protected]>

Closes #9406 from hvanhovell/SPARK-9241-rewriter.

(cherry picked from commit 6d0ead3)
Signed-off-by: Michael Armbrust <[email protected]>
1 parent 40a5db5 commit 162a770
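
To make the rewrite concrete, here is a minimal sketch (not from the patch) of the kind of query this rule targets, using the SQLContext API of this era; the table `t` and its columns `key`, `a`, `b`, and `c` are hypothetical:

```scala
// Hypothetical query with two DISTINCT clauses plus a regular aggregate.
// Before this patch, such a query could not be planned on the new
// (TungstenAggregate) code path.
val df = sqlContext.sql("""
  SELECT key, COUNT(DISTINCT a), COUNT(DISTINCT b), SUM(c)
  FROM t
  GROUP BY key
""")
```

The rewriter turns the single Aggregate over this query into an Expand that emits three copies of every input row (one per distinct clause, plus one for `SUM(c)`), each tagged with a group id, followed by two Aggregates: the first de-duplicates the distinct children per group, and the second combines the per-group results.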

File tree: 6 files changed, +238 −44 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala

Lines changed: 2 additions & 0 deletions
```diff
@@ -49,4 +49,6 @@ case class Count(child: Expression) extends DeclarativeAggregate {
   )
 
   override val evaluateExpression = Cast(count, LongType)
+
+  override def defaultResult: Option[Literal] = Option(Literal(0L))
 }
```
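
The new `defaultResult` matters for the rewrite in Utils.scala below: the second aggregate fetches a regular aggregate's value via `First(...)`, which yields null when the group saw no rows, and the rewriter wraps that in `Coalesce` with this default. A plain-Scala sketch (illustrative values, not Spark API) of that fallback:

```scala
// Models Coalesce(Seq(First(...), defaultResult)) from the rewriter below.
val firstValue: Option[Long] = None            // First(...) over an empty group -> null
val defaultResult: Option[Long] = Some(0L)     // Count.defaultResult
val patched = firstValue.orElse(defaultResult) // the Coalesce fallback
assert(patched.contains(0L))                   // COUNT of no rows is 0, not null
```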

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala

Lines changed: 182 additions & 4 deletions
```diff
@@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType}
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -41,7 +42,7 @@ object Utils {
 
   private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
     case p: Aggregate if supportsGroupingKeySchema(p) =>
-      val converted = p.transformExpressionsDown {
+      val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown {
         case expressions.Average(child) =>
           aggregate.AggregateExpression2(
             aggregateFunction = aggregate.Average(child),
@@ -144,7 +145,8 @@ object Utils {
             aggregateFunction = aggregate.VarianceSamp(child),
             mode = aggregate.Complete,
             isDistinct = false)
-      }
+      })
+
       // Check if there is any expressions.AggregateExpression1 left.
       // If so, we cannot convert this plan.
       val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
@@ -156,6 +158,7 @@ object Utils {
       }
 
       // Check if there are multiple distinct columns.
+      // TODO remove this.
       val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
         expr.collect {
           case agg: AggregateExpression2 => agg
@@ -213,3 +216,178 @@ object Utils {
     case other => None
   }
 }
+
+/**
+ * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double
+ * aggregation in which the regular aggregation expressions and every distinct clause is aggregated
+ * in a separate group. The results are then combined in a second aggregate.
+ *
+ * TODO Expression canonicalization
+ * TODO Eliminate foldable expressions from distinct clauses.
+ * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
+ *      operator. Perhaps this is a good thing? It is much simpler to plan later on...
+ */
+object MultipleDistinctRewriter extends Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    case a: Aggregate => rewrite(a)
+    case p => p
+  }
+
+  def rewrite(a: Aggregate): Aggregate = {
+
+    // Collect all aggregate expressions.
+    val aggExpressions = a.aggregateExpressions.flatMap { e =>
+      e.collect {
+        case ae: AggregateExpression2 => ae
+      }
+    }
+
+    // Extract distinct aggregate expressions.
+    val distinctAggGroups = aggExpressions
+      .filter(_.isDistinct)
+      .groupBy(_.aggregateFunction.children.toSet)
+
+    // Only continue to rewrite if there is more than one distinct group.
+    if (distinctAggGroups.size > 1) {
+      // Create the attributes for the grouping id and the group by clause.
+      val gid = new AttributeReference("gid", IntegerType, false)()
+      val groupByMap = a.groupingExpressions.collect {
+        case ne: NamedExpression => ne -> ne.toAttribute
+        case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)()
+      }
+      val groupByAttrs = groupByMap.map(_._2)
+
+      // Functions used to modify aggregate functions and their inputs.
+      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
+      def patchAggregateFunctionChildren(
+          af: AggregateFunction2,
+          id: Literal,
+          attrs: Map[Expression, Expression]): AggregateFunction2 = {
+        af.withNewChildren(af.children.map { case afc =>
+          evalWithinGroup(id, attrs(afc))
+        }).asInstanceOf[AggregateFunction2]
+      }
+
+      // Setup unique distinct aggregate children.
+      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
+      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
+      val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
+
+      // Setup expand & aggregate operators for distinct aggregate expressions.
+      val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
+        case ((group, expressions), i) =>
+          val id = Literal(i + 1)
+
+          // Expand projection
+          val projection = distinctAggChildren.map {
+            case e if group.contains(e) => e
+            case e => nullify(e)
+          } :+ id
+
+          // Final aggregate
+          val operators = expressions.map { e =>
+            val af = e.aggregateFunction
+            val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
+            (e, e.copy(aggregateFunction = naf, isDistinct = false))
+          }
+
+          (projection, operators)
+      }
+
+      // Setup expand for the 'regular' aggregate expressions.
+      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
+      val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+      val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap
+
+      // Setup aggregates for 'regular' aggregate expressions.
+      val regularGroupId = Literal(0)
+      val regularAggOperatorMap = regularAggExprs.map { e =>
+        // Perform the actual aggregation in the initial aggregate.
+        val af = patchAggregateFunctionChildren(
+          e.aggregateFunction,
+          regularGroupId,
+          regularAggChildAttrMap)
+        val a = Alias(e.copy(aggregateFunction = af), e.toString)()
+
+        // Get the result of the first aggregate in the last aggregate.
+        val b = AggregateExpression2(
+          aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)),
+          mode = Complete,
+          isDistinct = false)
+
+        // Some aggregate functions (COUNT) have the special property that they can return a
+        // non-null result without any input. We need to make sure we return a result in this case.
+        val c = af.defaultResult match {
+          case Some(lit) => Coalesce(Seq(b, lit))
+          case None => b
+        }
+
+        (e, a, c)
+      }
+
+      // Construct the regular aggregate input projection only if we need one.
+      val regularAggProjection = if (regularAggExprs.nonEmpty) {
+        Seq(a.groupingExpressions ++
+          distinctAggChildren.map(nullify) ++
+          Seq(regularGroupId) ++
+          regularAggChildren)
+      } else {
+        Seq.empty[Seq[Expression]]
+      }
+
+      // Construct the distinct aggregate input projections.
+      val regularAggNulls = regularAggChildren.map(nullify)
+      val distinctAggProjections = distinctAggOperatorMap.map {
+        case (projection, _) =>
+          a.groupingExpressions ++
+            projection ++
+            regularAggNulls
+      }
+
+      // Construct the expand operator.
+      val expand = Expand(
+        regularAggProjection ++ distinctAggProjections,
+        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq,
+        a.child)
+
+      // Construct the first aggregate operator. This de-duplicates all the children of the
+      // distinct operators, and applies the regular aggregate operators.
+      val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
+      val firstAggregate = Aggregate(
+        firstAggregateGroupBy,
+        firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
+        expand)
+
+      // Construct the second aggregate
+      val transformations: Map[Expression, Expression] =
+        (distinctAggOperatorMap.flatMap(_._2) ++
+          regularAggOperatorMap.map(e => (e._1, e._3))).toMap
+
+      val patchedAggExpressions = a.aggregateExpressions.map { e =>
+        e.transformDown {
+          case e: Expression =>
+            // The same GROUP BY clauses can have different forms (different names for instance) in
+            // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
+            // tricky. So we do a linear search for a semantically equal group by expression.
+            groupByMap
+              .find(ge => e.semanticEquals(ge._1))
+              .map(_._2)
+              .getOrElse(transformations.getOrElse(e, e))
+        }.asInstanceOf[NamedExpression]
+      }
+      Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
+    } else {
+      a
+    }
+  }
+
+  private def nullify(e: Expression) = Literal.create(null, e.dataType)
+
+  private def expressionAttributePair(e: Expression) =
+    // We are creating a new reference here instead of reusing the attribute in case of a
+    // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
+    // children; in this case attribute reuse causes the input of the regular aggregate to be
+    // bound to the (nulled out) input of the distinct aggregate.
+    e -> new AttributeReference(e.prettyName, e.dataType, true)()
+}
```
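
As an illustration of what the rewritten plan computes, here is a self-contained plain-Scala simulation (no Spark; the data and column names are invented) of the Expand-Aggregate-Aggregate shape for `SELECT key, COUNT(DISTINCT a), COUNT(DISTINCT b) FROM t GROUP BY key`:

```scala
// One row of the expanded relation: (key, a, b, gid).
case class ExpandedRow(key: Int, a: Option[Int], b: Option[Int], gid: Int)

val input = Seq((1, 10, 20), (1, 10, 30), (2, 40, 50)) // (key, a, b)

// Expand: one copy of each input row per distinct clause, with the children of
// the other clauses nulled out and a group id appended.
val expanded = input.flatMap { case (key, a, b) =>
  Seq(ExpandedRow(key, Some(a), None, 1),  // gid 1: DISTINCT a
      ExpandedRow(key, None, Some(b), 2))  // gid 2: DISTINCT b
}

// First aggregate: grouping by (key, a, b, gid) de-duplicates distinct children.
val deduped = expanded.distinct

// Second aggregate: group by key and count the surviving non-null values per gid.
val result = deduped.groupBy(_.key).map { case (key, rows) =>
  (key,
   rows.count(r => r.gid == 1 && r.a.isDefined),  // COUNT(DISTINCT a)
   rows.count(r => r.gid == 2 && r.b.isDefined))  // COUNT(DISTINCT b)
}
// result contains (1, 1, 2) and (2, 1, 1): key 1 has one distinct a and two distinct b.
```

Note how the row count doubles at the Expand, which is exactly the memory-pressure trade-off called out in the commit message.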

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 6 additions & 0 deletions
```diff
@@ -133,6 +133,12 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp
    */
   def supportsPartial: Boolean = true
 
+  /**
+   * Result of the aggregate function when the input is empty. This is currently only used for the
+   * proper rewriting of distinct aggregate functions.
+   */
+  def defaultResult: Option[Literal] = None
+
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
     throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
 }
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -200,9 +200,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child))
-      if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty =>
-      a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references)))
+    case a @ Aggregate(_, _, e @ Expand(_, _, child))
+      if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty =>
+      a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) ++ a.references)))
 
     // Eliminate attributes that are not needed to calculate the specified aggregates.
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
```
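
The pruning condition above reads as plain set arithmetic over attributes; a hypothetical sketch with attributes modeled as strings (names invented):

```scala
// The Expand's child only needs to produce what the Expand outputs or the
// Aggregate references; anything else can be pruned away below the Expand.
val childOutput   = Set("key", "a", "b", "unused")
val expandOutput  = Set("key", "a", "b", "gid")
val aggReferences = Set("key", "a", "gid")

val prunable = childOutput -- expandOutput -- aggReferences // Set("unused")
val required = expandOutput ++ aggReferences                // what prunedChild keeps
assert(prunable.nonEmpty)                                   // so the rule fires here
```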

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 44 additions & 36 deletions
```diff
@@ -235,33 +235,17 @@ case class Window(
     projectList ++ windowExpressions.map(_.toAttribute)
 }
 
-/**
- * Apply the all of the GroupExpressions to every input row, hence we will get
- * multiple output rows for a input row.
- * @param bitmasks The bitmask set represents the grouping sets
- * @param groupByExprs The grouping by expressions
- * @param child Child operator
- */
-case class Expand(
-    bitmasks: Seq[Int],
-    groupByExprs: Seq[Expression],
-    gid: Attribute,
-    child: LogicalPlan) extends UnaryNode {
-  override def statistics: Statistics = {
-    val sizeInBytes = child.statistics.sizeInBytes * projections.length
-    Statistics(sizeInBytes = sizeInBytes)
-  }
-
-  val projections: Seq[Seq[Expression]] = expand()
-
+private[sql] object Expand {
   /**
-   * Extract attribute set according to the grouping id
+   * Extract attribute set according to the grouping id.
+   *
    * @param bitmask bitmask to represent the selected of the attribute sequence
    * @param exprs the attributes in sequence
    * @return the attributes of non selected specified via bitmask (with the bit set to 1)
    */
-  private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
-  : OpenHashSet[Expression] = {
+  private def buildNonSelectExprSet(
+      bitmask: Int,
+      exprs: Seq[Expression]): OpenHashSet[Expression] = {
     val set = new OpenHashSet[Expression](2)
 
     var bit = exprs.length - 1
@@ -274,18 +258,28 @@ case class Expand(
   }
 
   /**
-   * Create an array of Projections for the child projection, and replace the projections'
-   * expressions which equal GroupBy expressions with Literal(null), if those expressions
-   * are not set for this grouping set (according to the bit mask).
+   * Apply all of the GroupExpressions to every input row, hence we will get
+   * multiple output rows for an input row.
+   *
+   * @param bitmasks The bitmask set represents the grouping sets
+   * @param groupByExprs The grouping by expressions
+   * @param gid Attribute of the grouping id
+   * @param child Child operator
    */
-  private[this] def expand(): Seq[Seq[Expression]] = {
-    val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
-
-    bitmasks.foreach { bitmask =>
+  def apply(
+      bitmasks: Seq[Int],
+      groupByExprs: Seq[Expression],
+      gid: Attribute,
+      child: LogicalPlan): Expand = {
+    // Create an array of Projections for the child projection, and replace the projections'
+    // expressions which equal GroupBy expressions with Literal(null), if those expressions
+    // are not set for this grouping set (according to the bit mask).
+    val projections = bitmasks.map { bitmask =>
       // get the non selected grouping attributes according to the bit mask
       val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)
 
-      val substitution = (child.output :+ gid).map(expr => expr transformDown {
+      (child.output :+ gid).map(expr => expr transformDown {
+        // TODO this causes a problem when a column is used both for grouping and aggregation.
         case x: Expression if nonSelectedGroupExprSet.contains(x) =>
           // if the input attribute in the Invalid Grouping Expression set of for this group
           // replace it with constant null
@@ -294,15 +288,29 @@ case class Expand(
           // replace the groupingId with concrete value (the bit mask)
           Literal.create(bitmask, IntegerType)
       })
-
-      result += substitution
     }
-
-    result.toSeq
+    Expand(projections, child.output :+ gid, child)
   }
+}
 
-  override def output: Seq[Attribute] = {
-    child.output :+ gid
+/**
+ * Apply a number of projections to every input row, hence we will get multiple output rows for
+ * an input row.
+ *
+ * @param projections to apply
+ * @param output of all projections.
+ * @param child operator.
+ */
+case class Expand(
+    projections: Seq[Seq[Expression]],
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+
+  override def statistics: Statistics = {
+    // TODO shouldn't we factor in the size of the projection versus the size of the backing child
+    // row?
+    val sizeInBytes = child.statistics.sizeInBytes * projections.length
+    Statistics(sizeInBytes = sizeInBytes)
   }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -420,7 +420,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       }
     case logical.Filter(condition, child) =>
       execution.Filter(condition, planLater(child)) :: Nil
-    case e @ logical.Expand(_, _, _, child) =>
+    case e @ logical.Expand(_, _, child) =>
       execution.Expand(e.projections, e.output, planLater(child)) :: Nil
     case a @ logical.Aggregate(group, agg, child) => {
       val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled
```
