Skip to content

Commit c3614d7

Browse files
committed
Handle single distinct column.
1 parent 68b8ee9 commit c3614d7

File tree

8 files changed

+476
-164
lines changed

8 files changed

+476
-164
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ private[sql] case object Final extends AggregateMode
5858
*/
5959
private[sql] case object Complete extends AggregateMode
6060

61+
/**
62+
* A placeholder expression used in code-gen; it does not change the corresponding value
63+
* in the row.
64+
*/
6165
private[sql] case object NoOp extends Expression with Unevaluable {
6266
override def nullable: Boolean = true
6367
override def eval(input: InternalRow): Any = {

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,15 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
281281
}
282282

283283
def addSortIfNecessary(child: SparkPlan): SparkPlan = {
284-
if (rowOrdering.nonEmpty && child.outputOrdering != rowOrdering) {
285-
sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
284+
285+
if (rowOrdering.nonEmpty) {
286+
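// If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort.
// NOTE(review): the prefix comparison below also skips the sort when child.outputOrdering
// is a shorter prefix of rowOrdering (e.g. [a] vs. [a, b]) -- confirm this is intended.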
287+
val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min
288+
if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
289+
sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
290+
} else {
291+
child
292+
}
286293
} else {
287294
child
288295
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
203203
private def planAggregateWithoutDistinct(
204204
groupingExpressions: Seq[Expression],
205205
aggregateExpressions: Seq[AggregateExpression2],
206-
aggregateFunctionMap: Map[AggregateFunction2, Attribute],
206+
aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
207207
resultExpressions: Seq[NamedExpression],
208208
child: SparkPlan): Seq[SparkPlan] = {
209209
// 1. Create an Aggregate Operator for partial aggregations.
@@ -241,12 +241,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
241241
}
242242
val finalAggregateAttributes =
243243
finalAggregateExpressions.map {
244-
expr => aggregateFunctionMap(expr.aggregateFunction)
244+
expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
245245
}
246246
val rewrittenResultExpressions = resultExpressions.map { expr =>
247247
expr.transform {
248248
case agg: AggregateExpression2 =>
249-
aggregateFunctionMap(agg.aggregateFunction).toAttribute
249+
aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
250250
case expression if groupExpressionMap.contains(expression) =>
251251
groupExpressionMap(expression).toAttribute
252252
}.asInstanceOf[NamedExpression]
@@ -266,7 +266,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
266266
groupingExpressions: Seq[Expression],
267267
functionsWithDistinct: Seq[AggregateExpression2],
268268
functionsWithoutDistinct: Seq[AggregateExpression2],
269-
aggregateFunctionMap: Map[AggregateFunction2, Attribute],
269+
aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
270270
resultExpressions: Seq[NamedExpression],
271271
child: SparkPlan): Seq[SparkPlan] = {
272272

@@ -306,7 +306,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
306306
val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
307307
agg.aggregateFunction.bufferAttributes
308308
}
309-
println("namedDistinctColumnExpressions " + namedDistinctColumnExpressions)
310309
val partialAggregate =
311310
Aggregate2Sort(
312311
None: Option[Seq[Expression]],
@@ -323,7 +322,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
323322
}
324323
val partialMergeAggregateAttributes =
325324
partialMergeAggregateExpressions.map {
326-
expr => aggregateFunctionMap(expr.aggregateFunction)
325+
expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
327326
}
328327
val partialMergeAggregate =
329328
Aggregate2Sort(
@@ -336,34 +335,41 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
336335

337336
// 3. Create an Aggregate Operator for final and complete aggregations.
338337
val finalAggregateExpressions = functionsWithoutDistinct.map {
339-
Need to replace the children to distinctColumnAttributes
340338
case AggregateExpression2(aggregateFunction, mode, _) =>
341339
AggregateExpression2(aggregateFunction, Final, false)
342340
}
343341
val finalAggregateAttributes =
344342
finalAggregateExpressions.map {
345-
expr => aggregateFunctionMap(expr.aggregateFunction)
346-
}
347-
val completeAggregateExpressions = functionsWithDistinct.map {
348-
case AggregateExpression2(aggregateFunction, mode, _) =>
349-
AggregateExpression2(aggregateFunction, Complete, false)
350-
}
351-
val completeAggregateAttributes =
352-
completeAggregateExpressions.map {
353-
expr => aggregateFunctionMap(expr.aggregateFunction)
343+
expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
354344
}
345+
val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
346+
// Children of an AggregateFunction with the DISTINCT keyword have already
347+
// been evaluated. Here, we need to replace the original children
348+
// with AttributeReferences.
349+
case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) =>
350+
val rewrittenAggregateFunction = aggregateFunction.transformDown {
351+
case expr if distinctColumnExpressionMap.contains(expr) =>
352+
distinctColumnExpressionMap(expr).toAttribute
353+
}.asInstanceOf[AggregateFunction2]
354+
// We rewrite the aggregate function to a non-distinct aggregation because
355+
// its input will have distinct arguments.
356+
val rewrittenAggregateExpression =
357+
AggregateExpression2(rewrittenAggregateFunction, Complete, false)
358+
359+
val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct)
360+
(rewrittenAggregateExpression -> aggregateFunctionAttribute)
361+
}.unzip
355362

356363
val rewrittenResultExpressions = resultExpressions.map { expr =>
357364
expr.transform {
358365
case agg: AggregateExpression2 =>
359-
aggregateFunctionMap(agg.aggregateFunction).toAttribute
366+
aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
360367
case expression if groupExpressionMap.contains(expression) =>
361368
groupExpressionMap(expression).toAttribute
362-
case expression if distinctColumnExpressionMap.contains(expression) =>
363-
distinctColumnExpressionMap(expression).toAttribute
364369
}.asInstanceOf[NamedExpression]
365370
}
366371
val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
372+
namedGroupingAttributes ++ distinctColumnAttributes,
367373
namedGroupingAttributes,
368374
finalAggregateExpressions,
369375
finalAggregateAttributes,
@@ -378,7 +384,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
378384
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
379385
case logical.Aggregate(groupingExpressions, resultExpressions, child)
380386
if sqlContext.conf.useSqlAggregate2 =>
381-
// 1. Extracts all distinct aggregate expressions from the resultExpressions.
387+
// Extracts all distinct aggregate expressions from the resultExpressions.
382388
val aggregateExpressions = resultExpressions.flatMap { expr =>
383389
expr.collect {
384390
case agg: AggregateExpression2 => agg
@@ -388,12 +394,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
388394
// to the corresponding attribute of the function.
389395
val aggregateFunctionMap = aggregateExpressions.map { agg =>
390396
val aggregateFunction = agg.aggregateFunction
391-
aggregateFunction -> Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
397+
(aggregateFunction, agg.isDistinct) ->
398+
Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
392399
}.toMap
393400

394401
val (functionsWithDistinct, functionsWithoutDistinct) =
395402
aggregateExpressions.partition(_.isDistinct)
396-
println("functionsWithDistinct " + functionsWithDistinct)
397403
if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
398404
// This is a sanity check. We should not reach here since we check the same thing in
399405
// CheckAggregateFunction.

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/aggregateOperators.scala

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,6 @@ case class Aggregate2Sort(
3434
child: SparkPlan)
3535
extends UnaryNode {
3636

37-
/** Indicates if this operator is for partial aggregations. */
38-
39-
4037
override def references: AttributeSet = {
4138
val referencesInResults =
4239
AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes)
@@ -55,8 +52,18 @@ case class Aggregate2Sort(
5552
}
5653
}
5754

58-
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
55+
override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
56+
// TODO: We should not sort the input rows if they are just in reversed order.
5957
groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
58+
}
59+
60+
override def outputOrdering: Seq[SortOrder] = {
61+
// It is possible that the child.outputOrdering starts with the required
62+
// ordering expressions (e.g. we require [a] as the sort expression and the
63+
// child's outputOrdering is [a, b]). We can only guarantee the output rows
64+
// are sorted by values of groupingExpressions.
65+
groupingExpressions.map(SortOrder(_, Ascending))
66+
}
6067

6168
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
6269

@@ -69,42 +76,46 @@ case class Aggregate2Sort(
6976
child.output,
7077
iter)
7178
} else {
72-
val partialAggregation: Boolean = {
79+
val aggregationIterator: SortAggregationIterator = {
7380
aggregateExpressions.map(_.mode).distinct.toList match {
74-
case Partial :: Nil => true
75-
case Final :: Nil => false
76-
TODO: HANDLE PARTIAL MERGE
81+
case Partial :: Nil =>
82+
new PartialSortAggregationIterator(
83+
groupingExpressions,
84+
aggregateExpressions,
85+
newMutableProjection,
86+
child.output,
87+
iter)
88+
case PartialMerge :: Nil =>
89+
new PartialMergeSortAggregationIterator(
90+
groupingExpressions,
91+
aggregateExpressions,
92+
newMutableProjection,
93+
child.output,
94+
iter)
95+
case Final :: Nil =>
96+
new FinalSortAggregationIterator(
97+
groupingExpressions,
98+
aggregateExpressions,
99+
aggregateAttributes,
100+
resultExpressions,
101+
newMutableProjection,
102+
child.output,
103+
iter)
77104
case other =>
78105
sys.error(
79106
s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " +
80107
s"modes $other in this operator.")
81108
}
82109
}
83-
val aggregationIterator =
84-
if (partialAggregation) {
85-
new PartialSortAggregationIterator(
86-
groupingExpressions,
87-
aggregateExpressions,
88-
newMutableProjection,
89-
child.output,
90-
iter)
91-
} else {
92-
new FinalSortAggregationIterator(
93-
groupingExpressions,
94-
aggregateExpressions,
95-
aggregateAttributes,
96-
resultExpressions,
97-
newMutableProjection,
98-
child.output,
99-
iter)
100-
}
110+
101111
aggregationIterator
102112
}
103113
}
104114
}
105115
}
106116

107117
case class FinalAndCompleteAggregate2Sort(
118+
previousGroupingExpressions: Seq[NamedExpression],
108119
groupingExpressions: Seq[NamedExpression],
109120
finalAggregateExpressions: Seq[AggregateExpression2],
110121
finalAggregateAttributes: Seq[Attribute],
@@ -143,6 +154,7 @@ case class FinalAndCompleteAggregate2Sort(
143154
child.execute().mapPartitions { iter =>
144155

145156
new FinalAndCompleteSortAggregationIterator(
157+
previousGroupingExpressions.length,
146158
groupingExpressions,
147159
finalAggregateExpressions,
148160
finalAggregateAttributes,

0 commit comments

Comments
 (0)