Commit 2816c89

JoshRosen authored and yhuai committed
[SPARK-10988] [SQL] Reduce duplication in Aggregate2's expression rewriting logic
In `aggregate/utils.scala`, there is a substantial amount of duplication in the expression-rewriting logic. As a prerequisite to supporting imperative aggregate functions in `TungstenAggregate`, this patch refactors this file so that the same expression-rewriting logic is used for both `SortAggregate` and `TungstenAggregate`.

In order to allow both operators to use the same rewriting logic, `TungstenAggregationIterator.generateResultProjection()` has been updated so that it first evaluates all declarative aggregate functions' `evaluateExpression`s and writes the results into a temporary buffer, and then uses this temporary buffer and the grouping expressions to evaluate the final `resultExpressions`. This matches the logic in `SortAggregateIterator`, where this two-pass approach is necessary in order to support imperative aggregates. If this change turns out to cause performance regressions, then we can look into re-implementing the single-pass evaluation in a cleaner way as part of a followup patch.

Since the rewriting logic is now shared across both operators, this patch also extracts that logic and places it in `SparkStrategies`. This makes the rewriting logic a bit easier to follow, I think.

Author: Josh Rosen <[email protected]>

Closes #9015 from JoshRosen/SPARK-10988.
1 parent 9e66a53 commit 2816c89

5 files changed: +143 -196 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 52 additions & 15 deletions
@@ -195,19 +195,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         converted match {
           case None => Nil // Cannot convert to new aggregation code path.
           case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
-            // Extracts all distinct aggregate expressions from the resultExpressions.
+            // A single aggregate expression might appear multiple times in resultExpressions.
+            // In order to avoid evaluating an individual aggregate function multiple times, we'll
+            // build a set of the distinct aggregate expressions and build a function which can
+            // be used to re-write expressions so that they reference the single copy of the
+            // aggregate function which actually gets computed.
             val aggregateExpressions = resultExpressions.flatMap { expr =>
               expr.collect {
                 case agg: AggregateExpression2 => agg
               }
-            }.toSet.toSeq
+            }.distinct
             // For those distinct aggregate expressions, we create a map from the
             // aggregate function to the corresponding attribute of the function.
-            val aggregateFunctionMap = aggregateExpressions.map { agg =>
+            val aggregateFunctionToAttribute = aggregateExpressions.map { agg =>
               val aggregateFunction = agg.aggregateFunction
-              val attribtue = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
-              (aggregateFunction, agg.isDistinct) ->
-                (aggregateFunction -> attribtue)
+              val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+              (aggregateFunction, agg.isDistinct) -> attribute
             }.toMap

             val (functionsWithDistinct, functionsWithoutDistinct) =
@@ -220,33 +223,67 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
                 "code path.")
             }

+            val namedGroupingExpressions = groupingExpressions.map {
+              case ne: NamedExpression => ne -> ne
+              // If the expression is not a NamedExpressions, we add an alias.
+              // So, when we generate the result of the operator, the Aggregate Operator
+              // can directly get the Seq of attributes representing the grouping expressions.
+              case other =>
+                val withAlias = Alias(other, other.toString)()
+                other -> withAlias
+            }
+            val groupExpressionMap = namedGroupingExpressions.toMap
+
+            // The original `resultExpressions` are a set of expressions which may reference
+            // aggregate expressions, grouping column values, and constants. When aggregate operator
+            // emits output rows, we will use `resultExpressions` to generate an output projection
+            // which takes the grouping columns and final aggregate result buffer as input.
+            // Thus, we must re-write the result expressions so that their attributes match up with
+            // the attributes of the final result projection's input row:
+            val rewrittenResultExpressions = resultExpressions.map { expr =>
+              expr.transformDown {
+                case AggregateExpression2(aggregateFunction, _, isDistinct) =>
+                  // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
+                  // so replace each aggregate expression by its corresponding attribute in the set:
+                  aggregateFunctionToAttribute(aggregateFunction, isDistinct)
+                case expression =>
+                  // Since we're using `namedGroupingAttributes` to extract the grouping key
+                  // columns, we need to replace grouping key expressions with their corresponding
+                  // attributes. We do not rely on the equality check at here since attributes may
+                  // differ cosmetically. Instead, we use semanticEquals.
+                  groupExpressionMap.collectFirst {
+                    case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+                  }.getOrElse(expression)
+              }.asInstanceOf[NamedExpression]
+            }
+
             val aggregateOperator =
               if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
                 if (functionsWithDistinct.nonEmpty) {
                   sys.error("Distinct columns cannot exist in Aggregate operator containing " +
                     "aggregate functions which don't support partial aggregation.")
                 } else {
                   aggregate.Utils.planAggregateWithoutPartial(
-                    groupingExpressions,
+                    namedGroupingExpressions.map(_._2),
                     aggregateExpressions,
-                    aggregateFunctionMap,
-                    resultExpressions,
+                    aggregateFunctionToAttribute,
+                    rewrittenResultExpressions,
                     planLater(child))
                 }
               } else if (functionsWithDistinct.isEmpty) {
                 aggregate.Utils.planAggregateWithoutDistinct(
-                  groupingExpressions,
+                  namedGroupingExpressions.map(_._2),
                   aggregateExpressions,
-                  aggregateFunctionMap,
-                  resultExpressions,
+                  aggregateFunctionToAttribute,
+                  rewrittenResultExpressions,
                   planLater(child))
               } else {
                 aggregate.Utils.planAggregateWithOneDistinct(
-                  groupingExpressions,
+                  namedGroupingExpressions.map(_._2),
                   functionsWithDistinct,
                   functionsWithoutDistinct,
-                  aggregateFunctionMap,
-                  resultExpressions,
+                  aggregateFunctionToAttribute,
+                  rewrittenResultExpressions,
                   planLater(child))
               }

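The rewriting step in the hunk above (deduplicate aggregate expressions, map each one to a single attribute, then rewrite the result expressions top-down) can be illustrated without Spark. The following is a minimal sketch, not the patch's code: `Expr`, `Col`, `Lit`, `Agg`, `Add`, `Attr`, and `RewriteSketch` are hypothetical stand-ins for Catalyst classes, and plain case-class equality stands in for `semanticEquals`.

```scala
// Toy expression tree. These types are illustrative stand-ins, NOT Catalyst classes.
sealed trait Expr
case class Col(name: String) extends Expr
case class Lit(value: Int) extends Expr
case class Agg(fn: String, arg: Expr, isDistinct: Boolean = false) extends Expr
case class Add(left: Expr, right: Expr) extends Expr
case class Attr(name: String) extends Expr

object RewriteSketch {
  // Collect every aggregate call in an expression, in traversal order.
  def collectAggs(e: Expr): Seq[Agg] = e match {
    case a: Agg    => Seq(a)
    case Add(l, r) => collectAggs(l) ++ collectAggs(r)
    case _         => Seq.empty
  }

  // Top-down rewrite: aggregates and grouping expressions become attribute references.
  // Assumption: structural equality replaces the semanticEquals check used in the patch.
  def rewrite(e: Expr, aggToAttr: Map[Agg, Attr], groupToAttr: Map[Expr, Attr]): Expr =
    e match {
      case a: Agg                               => aggToAttr(a)
      case other if groupToAttr.contains(other) => groupToAttr(other)
      case Add(l, r) =>
        Add(rewrite(l, aggToAttr, groupToAttr), rewrite(r, aggToAttr, groupToAttr))
      case leaf => leaf
    }

  // Hypothetical query: SELECT k, sum(v) + sum(v), k + 1 ... GROUP BY k, k + 1
  val grouping: Seq[Expr] = Seq(Col("k"), Add(Col("k"), Lit(1)))
  val results: Seq[Expr] = Seq(
    Col("k"),
    Add(Agg("sum", Col("v")), Agg("sum", Col("v"))),
    Add(Col("k"), Lit(1)))

  // sum(v) appears twice in the results but is kept (and computed) once.
  val distinctAggs: Seq[Agg] = results.flatMap(collectAggs).distinct
  val aggToAttr: Map[Agg, Attr] = distinctAggs.map(a => a -> Attr(a.toString)).toMap
  val groupToAttr: Map[Expr, Attr] = grouping.map(g => g -> Attr(g.toString)).toMap
  val rewritten: Seq[Expr] = results.map(rewrite(_, aggToAttr, groupToAttr))

  def main(args: Array[String]): Unit = rewritten.foreach(println)
}
```

In this sketch both occurrences of `sum(v)` are rewritten to the same `Attr`, and the grouping expression `k + 1` is replaced as a whole rather than descended into, which mirrors the intent described in the diff's comments.
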
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala

Lines changed: 4 additions & 0 deletions
@@ -31,7 +31,9 @@ case class TungstenAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
     groupingExpressions: Seq[NamedExpression],
     nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    nonCompleteAggregateAttributes: Seq[Attribute],
     completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends UnaryNode {
@@ -77,7 +79,9 @@ case class TungstenAggregate(
         new TungstenAggregationIterator(
           groupingExpressions,
           nonCompleteAggregateExpressions,
+          nonCompleteAggregateAttributes,
           completeAggregateExpressions,
+          completeAggregateAttributes,
           resultExpressions,
           newMutableProjection,
           child.output,

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala

Lines changed: 18 additions & 4 deletions
@@ -60,8 +60,12 @@ import org.apache.spark.sql.types.StructType
  * @param nonCompleteAggregateExpressions
  *   [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]],
  *   [[PartialMerge]], or [[Final]].
+ * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions'
+ *   outputs when they are stored in the final aggregation buffer.
  * @param completeAggregateExpressions
  *   [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]].
+ * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs
+ *   when they are stored in the final aggregation buffer.
  * @param resultExpressions
  *   expressions for generating output rows.
  * @param newMutableProjection
@@ -72,7 +76,9 @@ import org.apache.spark.sql.types.StructType
 class TungstenAggregationIterator(
     groupingExpressions: Seq[NamedExpression],
     nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    nonCompleteAggregateAttributes: Seq[Attribute],
     completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
     originalInputAttributes: Seq[Attribute],
@@ -280,17 +286,25 @@ class TungstenAggregationIterator(
       // resultExpressions.
       case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
         val joinedRow = new JoinedRow()
+        val evalExpressions = allAggregateFunctions.map {
+          case ae: DeclarativeAggregate => ae.evaluateExpression
+          // case agg: AggregateFunction2 => Literal.create(null, agg.dataType)
+        }
+        val expressionAggEvalProjection = UnsafeProjection.create(evalExpressions, bufferAttributes)
+        // These are the attributes of the row produced by `expressionAggEvalProjection`
+        val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
         val resultProjection =
-          UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes)
+          UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema)

         (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
-          resultProjection(joinedRow(currentGroupingKey, currentBuffer))
+          // Generate results for all expression-based aggregate functions.
+          val aggregateResult = expressionAggEvalProjection.apply(currentBuffer)
+          resultProjection(joinedRow(currentGroupingKey, aggregateResult))
         }

       // Grouping-only: a output row is generated from values of grouping expressions.
       case (None, None) =>
-        val resultProjection =
-          UnsafeProjection.create(resultExpressions, groupingAttributes)
+        val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)

         (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
           resultProjection(currentGroupingKey)
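
The two-pass result generation introduced in the last hunk can be sketched in plain Scala. This is only an illustration under stated assumptions: `TwoPassSketch`, `Row`, and both projection functions are hypothetical, rows are modeled as `Vector[Double]` instead of `UnsafeRow`, and the aggregate is assumed to be an `avg(v)` whose buffer holds `(sum, count)`.

```scala
object TwoPassSketch {
  // Assumption: a row is just a vector of doubles (stand-in for UnsafeRow).
  type Row = Vector[Double]

  // Pass 1: evaluate the aggregate's final value from its raw buffer.
  // The buffer layout (sum, count) for a hypothetical avg(v) is an assumption.
  val evalProjection: Row => Row = buffer => Vector(buffer(0) / buffer(1))

  // Pass 2: the result expressions read groupingKey ++ aggregateResult,
  // never the raw buffer. Hypothetical result expressions: (k, avg(v) * 2).
  val resultProjection: Row => Row = in => Vector(in(0), in(1) * 2)

  def generateOutput(groupingKey: Row, buffer: Row): Row = {
    val aggregateResult = evalProjection(buffer)     // temporary buffer of final values
    resultProjection(groupingKey ++ aggregateResult) // joined row fed to pass 2
  }

  def main(args: Array[String]): Unit =
    println(generateOutput(Vector(7.0), Vector(10.0, 4.0)))
}
```

Because the second projection only sees finished aggregate values, its input schema no longer depends on how each function lays out its buffer, which is what lets the same rewritten result expressions serve both operators.
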
