@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans._
2424import org .apache .spark .sql .catalyst .plans .logical .{BroadcastHint , LogicalPlan }
2525import org .apache .spark .sql .catalyst .plans .physical ._
2626import org .apache .spark .sql .columnar .{InMemoryColumnarTableScan , InMemoryRelation }
27+ import org .apache .spark .sql .execution .aggregate2 .Aggregate2Sort
2728import org .apache .spark .sql .execution .{DescribeCommand => RunnableDescribeCommand }
2829import org .apache .spark .sql .parquet ._
2930import org .apache .spark .sql .sources .{CreateTableUsing , CreateTempTableUsing , DescribeCommand => LogicalDescribeCommand , _ }
@@ -186,67 +187,71 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
186187 exprs.flatMap(_.collect { case a : AggregateExpression => a })
187188 }
188189
190+ /**
191+ * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
192+ */
189193 object AggregateOperator2 extends Strategy {
190194 def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
191195 case logical.Aggregate (groupingExpressions, resultExpressions, child)
192196 if sqlContext.conf.useSqlAggregate2 =>
193- // 1. Extracts all aggregate expressions.
197+ // 1. Extract all distinct aggregate expressions from the resultExpressions.
194198 val aggregateExpressions = resultExpressions.flatMap { expr =>
195199 expr.collect {
196200 case agg : AggregateExpression2 => agg
197201 }
198202 }.toSet.toSeq
199- val aggregateExpressionMap = aggregateExpressions.zipWithIndex.map {
200- case (agg, index) =>
201- agg.aggregateFunction -> Alias (agg, s " _agg $index" )().toAttribute
203+ // For those distinct aggregate expressions, we create a map from the aggregate function
204+ // to the corresponding attribute of the function.
205+ val aggregateFunctionMap = aggregateExpressions.map { agg =>
206+ val aggregateFunction = agg.aggregateFunction
207+ aggregateFunction -> Alias (aggregateFunction, aggregateFunction.toString)().toAttribute
202208 }.toMap
203209
204- // 2. Create Pre-shuffle Aggregate Operator
205- val namedGroupingExpressions = groupingExpressions.zipWithIndex.map {
206- case (ne : NamedExpression , index) => ne
207- case (other, index) => Alias (other, s " _groupingExpr $index" )()
210+ // 2. Create an Aggregate Operator for partial aggregations.
211+ val namedGroupingExpressions = groupingExpressions.map {
212+ case ne : NamedExpression => ne
213+ // If the expression is not a NamedExpression, we add an alias.
214+ // So, when we generate the result of the operator, the Aggregate Operator
215+ // can directly get the Seq of attributes representing the grouping expressions.
216+ case other => Alias (other, other.toString)()
208217 }
209218 val namedGroupingAttributes = namedGroupingExpressions.map(_.toAttribute)
210- val preShuffleAggregateExpressions = aggregateExpressions.map {
219+ val partialAggregateExpressions = aggregateExpressions.map {
211220 case AggregateExpression2 (aggregateFunction, mode, isDistinct) =>
212221 AggregateExpression2 (aggregateFunction, Partial , isDistinct)
213222 }
214- val preShuffleAggregateAttributes = preShuffleAggregateExpressions.zipWithIndex.flatMap {
215- case (AggregateExpression2 (aggregateFunction, Partial , isDistinct), index) =>
216- aggregateFunction.bufferValueDataTypes.map {
217- case StructField (name, dataType, nullable, metadata) =>
218- AttributeReference (s " _partialAgg ${index}_ ${name}" , dataType, nullable, metadata)()
219- }
223+ val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
224+ agg.bufferAttributes
220225 }
221226 val partialAggregate =
222227 Aggregate2Sort (
223228 true ,
224229 namedGroupingExpressions,
225- preShuffleAggregateExpressions ,
226- preShuffleAggregateAttributes ,
227- namedGroupingAttributes ++ preShuffleAggregateAttributes ,
230+ partialAggregateExpressions ,
231+ partialAggregateAttributes ,
232+ namedGroupingAttributes ++ partialAggregateAttributes ,
228233 planLater(child))
229234
230- // 3. Create post-shuffle Aggregate Operator.
231- val postShuffleAggregateExpressions = aggregateExpressions.map {
235+ // 3. Create an Aggregate Operator for final aggregations .
236+ val finalAggregateExpressions = aggregateExpressions.map {
232237 case AggregateExpression2 (aggregateFunction, mode, isDistinct) =>
233238 AggregateExpression2 (aggregateFunction, Final , isDistinct)
234239 }
235- val postShuffleAggregateAttributes =
236- postShuffleAggregateExpressions .map {
237- expr => aggregateExpressionMap (expr.aggregateFunction)
240+ val finalAggregateAttributes =
241+ finalAggregateExpressions .map {
242+ expr => aggregateFunctionMap (expr.aggregateFunction)
238243 }
239244 val rewrittenResultExpressions = resultExpressions.map { expr =>
240245 expr.transform {
241246 case agg : AggregateExpression2 =>
242- aggregateExpressionMap (agg.aggregateFunction).toAttribute
247+ aggregateFunctionMap (agg.aggregateFunction).toAttribute
243248 }.asInstanceOf [NamedExpression ]
244249 }
245250 val finalAggregate = Aggregate2Sort (
246251 false ,
247252 namedGroupingAttributes,
248- postShuffleAggregateExpressions ,
249- postShuffleAggregateAttributes ,
253+ finalAggregateExpressions ,
254+ finalAggregateAttributes ,
250255 rewrittenResultExpressions,
251256 partialAggregate)
252257
0 commit comments