@@ -195,19 +195,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
converted match {
case None => Nil // Cannot convert to new aggregation code path.
case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
// Extracts all distinct aggregate expressions from the resultExpressions.
// A single aggregate expression might appear multiple times in resultExpressions.
// In order to avoid evaluating an individual aggregate function multiple times, we'll
// build a set of the distinct aggregate expressions and build a function which can
// be used to re-write expressions so that they reference the single copy of the
// aggregate function which actually gets computed.
val aggregateExpressions = resultExpressions.flatMap { expr =>
expr.collect {
case agg: AggregateExpression2 => agg
}
}.toSet.toSeq
}.distinct
// For those distinct aggregate expressions, we create a map from the
// aggregate function to the corresponding attribute of the function.
val aggregateFunctionMap = aggregateExpressions.map { agg =>
val aggregateFunctionToAttribute = aggregateExpressions.map { agg =>
val aggregateFunction = agg.aggregateFunction
val attribtue = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
(aggregateFunction, agg.isDistinct) ->
(aggregateFunction -> attribtue)
val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
(aggregateFunction, agg.isDistinct) -> attribute
}.toMap
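// Illustrative sketch only (plain Scala, no Spark classes; AggCall and the attribute
// strings below are made up) of the dedup-and-map step above, for a query such as
// SELECT sum(a), sum(a) + 1 FROM t: sum(a) is collected once per reference but
// evaluated only once. Pasteable into a Scala REPL.
case class AggCall(func: String, arg: String, isDistinct: Boolean = false)

val collectedCalls = Seq(AggCall("sum", "a"), AggCall("sum", "a"))  // one per reference
val dedupedCalls   = collectedCalls.distinct                        // Seq(AggCall(sum,a,false))
// Key by (function, isDistinct) and map to the attribute that will hold the result,
// mirroring `aggregateFunctionToAttribute`.
val callToAttribute = dedupedCalls.map { call =>
  (call.func, call.isDistinct) -> s"${call.func}(${call.arg})#1"
}.toMap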

val (functionsWithDistinct, functionsWithoutDistinct) =
@@ -220,33 +223,67 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
"code path.")
}

val namedGroupingExpressions = groupingExpressions.map {
case ne: NamedExpression => ne -> ne
// If the expression is not a NamedExpression, we add an alias.
// So, when we generate the result of the operator, the Aggregate Operator
// can directly get the Seq of attributes representing the grouping expressions.
case other =>
val withAlias = Alias(other, other.toString)()
other -> withAlias
}
val groupExpressionMap = namedGroupingExpressions.toMap

// The original `resultExpressions` are a set of expressions which may reference
// aggregate expressions, grouping column values, and constants. When the aggregate operator
// emits output rows, we will use `resultExpressions` to generate an output projection
// which takes the grouping columns and final aggregate result buffer as input.
// Thus, we must re-write the result expressions so that their attributes match up with
// the attributes of the final result projection's input row:
val rewrittenResultExpressions = resultExpressions.map { expr =>
expr.transformDown {
case AggregateExpression2(aggregateFunction, _, isDistinct) =>
// The final aggregation buffer's attributes will be `finalAggregationAttributes`,
// so replace each aggregate expression by its corresponding attribute in the set:
aggregateFunctionToAttribute(aggregateFunction, isDistinct)
case expression =>
// Since we're using `namedGroupingAttributes` to extract the grouping key
// columns, we need to replace grouping key expressions with their corresponding
// attributes. We do not rely on the equality check here since attributes may
// differ cosmetically. Instead, we use semanticEquals.
groupExpressionMap.collectFirst {
case (expr, ne) if expr semanticEquals expression => ne.toAttribute
}.getOrElse(expression)
}.asInstanceOf[NamedExpression]
}
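// A rough sketch (simplified stand-in types, not Spark's Expression hierarchy) of what
// the rewrite above produces for SELECT a + 1, sum(b) + 1 FROM t GROUP BY a + 1: the
// aggregate call becomes a reference to its final-buffer attribute, and the grouping
// expression is matched structurally (standing in for semanticEquals) and becomes a
// reference to its aliased attribute. Pasteable into a Scala REPL.
sealed trait Expr
case class Attr(name: String) extends Expr
case class Lit(v: Int) extends Expr
case class Add(l: Expr, r: Expr) extends Expr
case class Sum(child: Expr) extends Expr

def rewrite(e: Expr, groupMap: Map[Expr, Attr], aggMap: Map[Expr, Attr]): Expr = e match {
  case s: Sum                    => aggMap(s)    // aggregate call -> its result attribute
  case g if groupMap.contains(g) => groupMap(g)  // grouping expression -> its alias attribute
  case Add(l, r)                 => Add(rewrite(l, groupMap, aggMap), rewrite(r, groupMap, aggMap))
  case other                     => other
}

val groupMap  = Map[Expr, Attr](Add(Attr("a"), Lit(1)) -> Attr("(a + 1)#1"))
val aggMap    = Map[Expr, Attr](Sum(Attr("b"))         -> Attr("sum(b)#2"))
val rewritten = Seq(Add(Attr("a"), Lit(1)), Add(Sum(Attr("b")), Lit(1)))
  .map(rewrite(_, groupMap, aggMap))
// rewritten == Seq(Attr((a + 1)#1), Add(Attr(sum(b)#2), Lit(1)))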

val aggregateOperator =
if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
if (functionsWithDistinct.nonEmpty) {
sys.error("Distinct columns cannot exist in Aggregate operator containing " +
"aggregate functions which don't support partial aggregation.")
} else {
aggregate.Utils.planAggregateWithoutPartial(
groupingExpressions,
namedGroupingExpressions.map(_._2),
aggregateExpressions,
aggregateFunctionMap,
resultExpressions,
aggregateFunctionToAttribute,
rewrittenResultExpressions,
planLater(child))
}
} else if (functionsWithDistinct.isEmpty) {
aggregate.Utils.planAggregateWithoutDistinct(
groupingExpressions,
namedGroupingExpressions.map(_._2),
aggregateExpressions,
aggregateFunctionMap,
resultExpressions,
aggregateFunctionToAttribute,
rewrittenResultExpressions,
planLater(child))
} else {
aggregate.Utils.planAggregateWithOneDistinct(
groupingExpressions,
namedGroupingExpressions.map(_._2),
functionsWithDistinct,
functionsWithoutDistinct,
aggregateFunctionMap,
resultExpressions,
aggregateFunctionToAttribute,
rewrittenResultExpressions,
planLater(child))
}
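// A compact restatement, as a pure function, of the path selection above; the names
// below are illustrative, not Spark API. Functions without partial-aggregation support
// force the no-partial plan (and may not be combined with DISTINCT); otherwise the
// presence of a DISTINCT function selects the one-distinct plan.
sealed trait AggPlanPath
case object WithoutPartial extends AggPlanPath
case object WithoutDistinct extends AggPlanPath
case object WithOneDistinct extends AggPlanPath

def choosePath(allSupportPartial: Boolean, hasDistinct: Boolean): AggPlanPath =
  if (!allSupportPartial) {
    // Mirrors the sys.error above: DISTINCT requires partial aggregation support.
    require(!hasDistinct, "Distinct columns require partial aggregation support")
    WithoutPartial
  } else if (!hasDistinct) {
    WithoutDistinct
  } else {
    WithOneDistinct
  }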

@@ -31,7 +31,9 @@ case class TungstenAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]],
groupingExpressions: Seq[NamedExpression],
nonCompleteAggregateExpressions: Seq[AggregateExpression2],
nonCompleteAggregateAttributes: Seq[Attribute],
completeAggregateExpressions: Seq[AggregateExpression2],
completeAggregateAttributes: Seq[Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryNode {
@@ -77,7 +79,9 @@ case class TungstenAggregate(
new TungstenAggregationIterator(
groupingExpressions,
nonCompleteAggregateExpressions,
nonCompleteAggregateAttributes,
completeAggregateExpressions,
completeAggregateAttributes,
resultExpressions,
newMutableProjection,
child.output,
@@ -60,8 +60,12 @@ import org.apache.spark.sql.types.StructType
* @param nonCompleteAggregateExpressions
* [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]],
* [[PartialMerge]], or [[Final]].
* @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions'
* outputs when they are stored in the final aggregation buffer.
* @param completeAggregateExpressions
* [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]].
* @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs
* when they are stored in the final aggregation buffer.
* @param resultExpressions
* expressions for generating output rows.
* @param newMutableProjection
@@ -72,7 +76,9 @@ import org.apache.spark.sql.types.StructType
class TungstenAggregationIterator(
groupingExpressions: Seq[NamedExpression],
nonCompleteAggregateExpressions: Seq[AggregateExpression2],
nonCompleteAggregateAttributes: Seq[Attribute],
completeAggregateExpressions: Seq[AggregateExpression2],
completeAggregateAttributes: Seq[Attribute],
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
originalInputAttributes: Seq[Attribute],
@@ -280,17 +286,25 @@ class TungstenAggregationIterator(
// resultExpressions.
case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
val joinedRow = new JoinedRow()
val evalExpressions = allAggregateFunctions.map {
case ae: DeclarativeAggregate => ae.evaluateExpression
// case agg: AggregateFunction2 => Literal.create(null, agg.dataType)
Review comment (Contributor): Will we uncomment it? Or will we use NoOp?
}
val expressionAggEvalProjection = UnsafeProjection.create(evalExpressions, bufferAttributes)
// These are the attributes of the row produced by `expressionAggEvalProjection`
val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
val resultProjection =
UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes)
UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema)

(currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
resultProjection(joinedRow(currentGroupingKey, currentBuffer))
// Generate results for all expression-based aggregate functions.
val aggregateResult = expressionAggEvalProjection.apply(currentBuffer)
resultProjection(joinedRow(currentGroupingKey, aggregateResult))
}
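// A rough sketch of the two-step output generation above, using plain Seq values in
// place of UnsafeRow (the avg example and all names are illustrative). Step 1 turns
// the final aggregation buffer into final aggregate values, whose schema corresponds
// to nonCompleteAggregateAttributes ++ completeAggregateAttributes; step 2 projects
// the result expressions over groupingKey ++ those values.
// Example: SELECT key, avg(x) FROM t GROUP BY key, where avg's buffer is (sum, count).
def evalBuffer(buffer: Seq[Double]): Seq[Double] =
  Seq(buffer(0) / buffer(1))               // avg's final value = sum / count
def projectResult(row: Seq[Any]): Seq[Any] =
  Seq(row(0), row(1))                      // input: groupingAttributes ++ aggregateResultSchema

val groupingKey = Seq[Any]("k1")
val buffer      = Seq(10.0, 4.0)           // sum = 10.0, count = 4.0
val outputRow   = projectResult(groupingKey ++ evalBuffer(buffer))  // Seq("k1", 2.5)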

// Grouping-only: an output row is generated from the values of the grouping expressions.
case (None, None) =>
val resultProjection =
UnsafeProjection.create(resultExpressions, groupingAttributes)
val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)

(currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
resultProjection(currentGroupingKey)