[SPARK-9241][SQL] Supporting multiple DISTINCT columns (2) - Rewriting Rule #9406
Changes from all commits
@@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType}

 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -41,7 +42,7 @@ object Utils {
   private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
     case p: Aggregate if supportsGroupingKeySchema(p) =>
-      val converted = p.transformExpressionsDown {
+      val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown {
         case expressions.Average(child) =>
           aggregate.AggregateExpression2(
             aggregateFunction = aggregate.Average(child),
@@ -144,7 +145,8 @@ object Utils {
             aggregateFunction = aggregate.VarianceSamp(child),
             mode = aggregate.Complete,
             isDistinct = false)
-      }
+      })

       // Check if there is any expressions.AggregateExpression1 left.
       // If so, we cannot convert this plan.
       val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
@@ -156,6 +158,7 @@ object Utils {
       }

       // Check if there are multiple distinct columns.
+      // TODO remove this.
       val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
         expr.collect {
           case agg: AggregateExpression2 => agg
@@ -213,3 +216,178 @@ object Utils {
     case other => None
   }
 }

/**
 * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double
 * aggregation in which the regular aggregation expressions and every distinct clause is aggregated
 * in a separate group. The results are then combined in a second aggregate.
 *
 * TODO Expression canonicalization
 * TODO Eliminate foldable expressions from distinct clauses.
 * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
 *      operator. Perhaps this is a good thing? It is much simpler to plan later on...
|
Contributor: Just making sure I understand: these are all optimizations, not correctness concerns?

Contributor (Author): These are all optimizations. The last one is perhaps a question for @yhuai: we have a choice to rewrite all distinct expressions.

Contributor: Yeah, we can use this path to handle all cases. If I understand correctly, this rewriting approach will first create two logical Aggregate operators, so we shuffle data twice. Our current planning rule for a single distinct aggregate shuffles data once, which can be bad if we do not have a group by clause (because we will have a single reducer). To make the ideal decision, we would need statistics on the grouping columns and the distinct column. However, for the case of a single distinct column with no group by clause, I feel your rewriting approach should be strictly better. What do you think?
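To make the case being discussed concrete, here is a hedged sketch (not part of the PR; the table and column names `events` and `user_id` are invented, and in this PR the rule only fires when there is more than one distinct group):

```scala
import org.apache.spark.sql.SQLContext

// Hedged sketch; `events` and `user_id` are invented names.
def singleDistinctNoGroupBy(sqlContext: SQLContext) =
  // The current single-distinct planning rule shuffles this once, but with no grouping
  // key every row ends up on a single reducer. If the rewrite were applied here as well,
  // the first aggregate would be keyed by (user_id, gid), so de-duplication runs in
  // parallel and only the much smaller second aggregate is single-reducer, at the cost
  // of a second shuffle.
  sqlContext.sql("SELECT COUNT(DISTINCT user_id) FROM events")
```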
 */
object MultipleDistinctRewriter extends Rule[LogicalPlan] {

  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case a: Aggregate => rewrite(a)
    case p => p
  }

  def rewrite(a: Aggregate): Aggregate = {

    // Collect all aggregate expressions.
    val aggExpressions = a.aggregateExpressions.flatMap { e =>
      e.collect {
        case ae: AggregateExpression2 => ae
      }
    }

    // Extract distinct aggregate expressions.
    val distinctAggGroups = aggExpressions
      .filter(_.isDistinct)
      .groupBy(_.aggregateFunction.children.toSet)

    // Only continue to rewrite if there is more than one distinct group.
    if (distinctAggGroups.size > 1) {
      // Create the attributes for the grouping id and the group by clause.
      val gid = new AttributeReference("gid", IntegerType, false)()
      val groupByMap = a.groupingExpressions.collect {
        case ne: NamedExpression => ne -> ne.toAttribute
        case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)()
      }
      val groupByAttrs = groupByMap.map(_._2)

      // Functions used to modify aggregate functions and their inputs.
      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
      def patchAggregateFunctionChildren(
          af: AggregateFunction2,
          id: Literal,
          attrs: Map[Expression, Expression]): AggregateFunction2 = {
        af.withNewChildren(af.children.map { case afc =>
          evalWithinGroup(id, attrs(afc))
        }).asInstanceOf[AggregateFunction2]
      }

      // Setup unique distinct aggregate children.
      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
      val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq

      // Setup expand & aggregate operators for distinct aggregate expressions.
      val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
        case ((group, expressions), i) =>
          val id = Literal(i + 1)

          // Expand projection
          val projection = distinctAggChildren.map {
            case e if group.contains(e) => e
            case e => nullify(e)
          } :+ id

          // Final aggregate
          val operators = expressions.map { e =>
            val af = e.aggregateFunction
            val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
            (e, e.copy(aggregateFunction = naf, isDistinct = false))
          }

          (projection, operators)
      }

      // Setup expand for the 'regular' aggregate expressions.
      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
      val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
      val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap

      // Setup aggregates for 'regular' aggregate expressions.
      val regularGroupId = Literal(0)
      val regularAggOperatorMap = regularAggExprs.map { e =>
Contributor: Comment on what each tuple element is, or maybe even use a case class?

Contributor (Author): I'll add documentation in a follow-up PR.
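As a hedged illustration of the reviewer's suggestion (not part of this PR; the class and field names are invented), such a case class might document the `(e, a, c)` tuple built in `regularAggOperatorMap` like this:

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2

// Hypothetical sketch, not part of this PR: names the three elements of the
// (e, a, c) tuple produced for every regular (non-distinct) aggregate expression.
case class RegularAggRewrite(
    original: AggregateExpression2, // the aggregate expression as written in the query
    firstAggExpr: Alias,            // the aliased expression evaluated in the first Aggregate
    resultExpr: Expression)         // the expression that picks up that result in the second
                                    // Aggregate (a First, wrapped in Coalesce when the function
                                    // has a default result, e.g. COUNT)
```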
        // Perform the actual aggregation in the initial aggregate.
        val af = patchAggregateFunctionChildren(
          e.aggregateFunction,
          regularGroupId,
          regularAggChildAttrMap)
        val a = Alias(e.copy(aggregateFunction = af), e.toString)()

        // Get the result of the first aggregate in the last aggregate.
        val b = AggregateExpression2(
          aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)),
          mode = Complete,
          isDistinct = false)

        // Some aggregate functions (COUNT) have the special property that they can return a
        // non-null result without any input. We need to make sure we return a result in this case.
        val c = af.defaultResult match {
          case Some(lit) => Coalesce(Seq(b, lit))
          case None => b
        }

        (e, a, c)
      }

      // Construct the regular aggregate input projection only if we need one.
      val regularAggProjection = if (regularAggExprs.nonEmpty) {
        Seq(a.groupingExpressions ++
          distinctAggChildren.map(nullify) ++
          Seq(regularGroupId) ++
          regularAggChildren)
      } else {
        Seq.empty[Seq[Expression]]
      }

      // Construct the distinct aggregate input projections.
      val regularAggNulls = regularAggChildren.map(nullify)
      val distinctAggProjections = distinctAggOperatorMap.map {
        case (projection, _) =>
          a.groupingExpressions ++
            projection ++
            regularAggNulls
      }

      // Construct the expand operator.
      val expand = Expand(
        regularAggProjection ++ distinctAggProjections,
        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq,
        a.child)

      // Construct the first aggregate operator. This de-duplicates all the children of the
      // distinct operators, and applies the regular aggregate operators.
      val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
      val firstAggregate = Aggregate(
        firstAggregateGroupBy,
        firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
        expand)

      // Construct the second aggregate.
      val transformations: Map[Expression, Expression] =
        (distinctAggOperatorMap.flatMap(_._2) ++
          regularAggOperatorMap.map(e => (e._1, e._3))).toMap

      val patchedAggExpressions = a.aggregateExpressions.map { e =>
        e.transformDown {
          case e: Expression =>
            // The same GROUP BY clauses can have different forms (different names for instance)
            // in the groupBy and aggregate expressions of an aggregate. This makes a map lookup
            // tricky. So we do a linear search for a semantically equal group by expression.
|
Contributor: We've talked about adding an …
            groupByMap
              .find(ge => e.semanticEquals(ge._1))
              .map(_._2)
              .getOrElse(transformations.getOrElse(e, e))
        }.asInstanceOf[NamedExpression]
      }
      Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
    } else {
      a
    }
  }

  private def nullify(e: Expression) = Literal.create(null, e.dataType)

  private def expressionAttributePair(e: Expression) =
    // We are creating a new reference here instead of reusing the attribute in case of a
    // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
    // children; in that case attribute reuse causes the input of the regular aggregate to be
    // bound to the (nulled out) input of the distinct aggregate.
    e -> new AttributeReference(e.prettyName, e.dataType, true)()
}
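As a usage note (not part of the diff): since the rewriter extends `Rule[LogicalPlan]`, it can be applied to any resolved logical plan in addition to the call from `Utils.doConvert` shown above. A minimal sketch:

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Minimal sketch: apply the rule to an already-analyzed logical plan. Aggregates with
// at most one distinct group are returned unchanged.
def rewriteDistincts(analyzedPlan: LogicalPlan): LogicalPlan =
  MultipleDistinctRewriter(analyzedPlan)
```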
It would be really helpful if there was an example of what this rewrite looks like here.

I'll add an example in the follow-up PR.
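This is not the author's follow-up example, but as a hedged sketch of what the rewrite does (the table and column names `data`, `key`, `cat1`, `cat2`, and `value` are invented):

```scala
import org.apache.spark.sql.SQLContext

// Hedged illustration; all table and column names are invented.
def multiDistinctExample(sqlContext: SQLContext) = {
  // Two different DISTINCT column sets plus a regular aggregate.
  val df = sqlContext.sql(
    """SELECT key,
      |       COUNT(DISTINCT cat1) AS cat1_cnt,
      |       COUNT(DISTINCT cat2) AS cat2_cnt,
      |       SUM(value)           AS total
      |FROM data
      |GROUP BY key""".stripMargin)

  // Conceptually, MultipleDistinctRewriter turns the single Aggregate into:
  //
  //   Aggregate(groupBy = key)                   -- second aggregate: combines the groups
  //     count(cat1) restricted to rows with gid = 1
  //     count(cat2) restricted to rows with gid = 2
  //     first(total) restricted to rows with gid = 0
  //   Aggregate(groupBy = key, cat1, cat2, gid)  -- first aggregate: de-duplicates the
  //     sum(value) AS total                      --   distinct children, evaluates SUM
  //   Expand                                     -- one projection per distinct group, plus
  //     (key, null, null, 0, value)              --   one for the regular aggregates; gid
  //     (key, cat1, null, 1, null)               --   tags the group each row belongs to
  //     (key, null, cat2, 2, null)
  //   <original child of the Aggregate>
  df
}
```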