Skip to content

Commit 68b8ee9

Browse files
committed
Support single distinct column set. WIP
1 parent 3013579 commit 68b8ee9

File tree

12 files changed

+680
-205
lines changed

12 files changed

+680
-205
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,11 +266,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
266266
}
267267
}
268268
| ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
269-
{ case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
269+
{ case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) }
270270
| ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs =>
271271
lexical.normalizeKeyword(udfName) match {
272272
case "sum" => SumDistinct(exprs.head)
273273
case "count" => CountDistinct(exprs)
274+
case name => UnresolvedFunction(name, exprs, isDistinct = true)
274275
case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT")
275276
}
276277
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.analysis
1919

2020
import org.apache.spark.sql.AnalysisException
21-
import org.apache.spark.sql.catalyst.expressions.aggregate2.{Complete, AggregateExpression2, AggregateFunction2}
21+
import org.apache.spark.sql.catalyst.expressions.aggregate2.{DistinctAggregateExpression1, Complete, AggregateExpression2, AggregateFunction2}
2222
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.plans.logical._
@@ -278,7 +278,7 @@ class Analyzer(
278278
Project(
279279
projectList.flatMap {
280280
case s: Star => s.expand(child.output, resolver)
281-
case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
281+
case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) =>
282282
val expandedArgs = args.flatMap {
283283
case s: Star => s.expand(child.output, resolver)
284284
case o => o :: Nil
@@ -518,10 +518,12 @@ class Analyzer(
518518
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
519519
case q: LogicalPlan =>
520520
q transformExpressions {
521-
case u @ UnresolvedFunction(name, children) =>
521+
case u @ UnresolvedFunction(name, children, isDistinct) =>
522522
withPosition(u) {
523523
registry.lookupFunction(name, children) match {
524-
case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, false)
524+
case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
525+
case agg1: AggregateExpression1 if isDistinct =>
526+
DistinctAggregateExpression1(agg1)
525527
case other => other
526528
}
527529
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ object UnresolvedAttribute {
7373
def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name))
7474
}
7575

76-
case class UnresolvedFunction(name: String, children: Seq[Expression])
76+
case class UnresolvedFunction(
77+
name: String,
78+
children: Seq[Expression],
79+
isDistinct: Boolean)
7780
extends Expression with Unevaluable {
7881

7982
override def dataType: DataType = throw new UnresolvedException(this, "dataType")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,16 @@ private[sql] case object NoOp extends Expression with Unevaluable {
6868
override def children: Seq[Expression] = Nil
6969
}
7070

71+
private[sql] case class DistinctAggregateExpression1(
72+
aggregateExpression: AggregateExpression1) extends AggregateExpression {
73+
override def children: Seq[Expression] = aggregateExpression :: Nil
74+
override def dataType: DataType = aggregateExpression.dataType
75+
override def foldable: Boolean = aggregateExpression.foldable
76+
override def nullable: Boolean = aggregateExpression.nullable
77+
78+
override def toString: String = s"DISTINCT ${aggregateExpression.toString}"
79+
}
80+
7181
/**
7282
* A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field
7383
* (`isDistinct`) indicating if DISTINCT keyword is specified for this function.

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 200 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans._
2525
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
2626
import org.apache.spark.sql.catalyst.plans.physical._
2727
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
28-
import org.apache.spark.sql.execution.aggregate2.Aggregate2Sort
28+
import org.apache.spark.sql.execution.aggregate2.{FinalAndCompleteAggregate2Sort, Aggregate2Sort}
2929
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
3030
import org.apache.spark.sql.parquet._
3131
import org.apache.spark.sql.sources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
@@ -200,6 +200,181 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
200200
* Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
201201
*/
202202
object AggregateOperator2 extends Strategy {
203+
private def planAggregateWithoutDistinct(
204+
groupingExpressions: Seq[Expression],
205+
aggregateExpressions: Seq[AggregateExpression2],
206+
aggregateFunctionMap: Map[AggregateFunction2, Attribute],
207+
resultExpressions: Seq[NamedExpression],
208+
child: SparkPlan): Seq[SparkPlan] = {
209+
// 1. Create an Aggregate Operator for partial aggregations.
210+
val namedGroupingExpressions = groupingExpressions.map {
211+
case ne: NamedExpression => ne -> ne
212+
// If the expression is not a NamedExpression, we add an alias.
213+
// So, when we generate the result of the operator, the Aggregate Operator
214+
// can directly get the Seq of attributes representing the grouping expressions.
215+
case other =>
216+
val withAlias = Alias(other, other.toString)()
217+
other -> withAlias
218+
}
219+
val groupExpressionMap = namedGroupingExpressions.toMap
220+
val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
221+
val partialAggregateExpressions = aggregateExpressions.map {
222+
case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
223+
AggregateExpression2(aggregateFunction, Partial, isDistinct)
224+
}
225+
val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
226+
agg.aggregateFunction.bufferAttributes
227+
}
228+
val partialAggregate =
229+
Aggregate2Sort(
230+
None: Option[Seq[Expression]],
231+
namedGroupingExpressions.map(_._2),
232+
partialAggregateExpressions,
233+
partialAggregateAttributes,
234+
namedGroupingAttributes ++ partialAggregateAttributes,
235+
child)
236+
237+
// 2. Create an Aggregate Operator for final aggregations.
238+
val finalAggregateExpressions = aggregateExpressions.map {
239+
case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
240+
AggregateExpression2(aggregateFunction, Final, isDistinct)
241+
}
242+
val finalAggregateAttributes =
243+
finalAggregateExpressions.map {
244+
expr => aggregateFunctionMap(expr.aggregateFunction)
245+
}
246+
val rewrittenResultExpressions = resultExpressions.map { expr =>
247+
expr.transform {
248+
case agg: AggregateExpression2 =>
249+
aggregateFunctionMap(agg.aggregateFunction).toAttribute
250+
case expression if groupExpressionMap.contains(expression) =>
251+
groupExpressionMap(expression).toAttribute
252+
}.asInstanceOf[NamedExpression]
253+
}
254+
val finalAggregate = Aggregate2Sort(
255+
Some(namedGroupingAttributes),
256+
namedGroupingAttributes,
257+
finalAggregateExpressions,
258+
finalAggregateAttributes,
259+
rewrittenResultExpressions,
260+
partialAggregate)
261+
262+
finalAggregate :: Nil
263+
}
264+
265+
private def planAggregateWithOneDistinct(
266+
groupingExpressions: Seq[Expression],
267+
functionsWithDistinct: Seq[AggregateExpression2],
268+
functionsWithoutDistinct: Seq[AggregateExpression2],
269+
aggregateFunctionMap: Map[AggregateFunction2, Attribute],
270+
resultExpressions: Seq[NamedExpression],
271+
child: SparkPlan): Seq[SparkPlan] = {
272+
273+
// 1. Create an Aggregate Operator for partial aggregations.
274+
// The grouping expressions are original groupingExpressions and
275+
// distinct columns. For example, for avg(distinct value) ... group by key
276+
// the grouping expressions of this Aggregate Operator will be [key, value].
277+
val namedGroupingExpressions = groupingExpressions.map {
278+
case ne: NamedExpression => ne -> ne
279+
// If the expression is not a NamedExpression, we add an alias.
280+
// So, when we generate the result of the operator, the Aggregate Operator
281+
// can directly get the Seq of attributes representing the grouping expressions.
282+
case other =>
283+
val withAlias = Alias(other, other.toString)()
284+
other -> withAlias
285+
}
286+
val groupExpressionMap = namedGroupingExpressions.toMap
287+
val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
288+
289+
// It is safe to call head here since functionsWithDistinct has at least one
290+
// AggregateExpression2.
291+
val distinctColumnExpressions =
292+
functionsWithDistinct.head.aggregateFunction.children
293+
val namedDistinctColumnExpressions = distinctColumnExpressions.map {
294+
case ne: NamedExpression => ne -> ne
295+
case other =>
296+
val withAlias = Alias(other, other.toString)()
297+
other -> withAlias
298+
}
299+
val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
300+
val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
301+
302+
val partialAggregateExpressions = functionsWithoutDistinct.map {
303+
case AggregateExpression2(aggregateFunction, mode, _) =>
304+
AggregateExpression2(aggregateFunction, Partial, false)
305+
}
306+
val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
307+
agg.aggregateFunction.bufferAttributes
308+
}
309+
println("namedDistinctColumnExpressions " + namedDistinctColumnExpressions)
310+
val partialAggregate =
311+
Aggregate2Sort(
312+
None: Option[Seq[Expression]],
313+
(namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2),
314+
partialAggregateExpressions,
315+
partialAggregateAttributes,
316+
namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes,
317+
child)
318+
319+
// 2. Create an Aggregate Operator for partial merge aggregations.
320+
val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
321+
case AggregateExpression2(aggregateFunction, mode, _) =>
322+
AggregateExpression2(aggregateFunction, PartialMerge, false)
323+
}
324+
val partialMergeAggregateAttributes =
325+
partialMergeAggregateExpressions.map {
326+
expr => aggregateFunctionMap(expr.aggregateFunction)
327+
}
328+
val partialMergeAggregate =
329+
Aggregate2Sort(
330+
Some(namedGroupingAttributes),
331+
namedGroupingAttributes ++ distinctColumnAttributes,
332+
partialMergeAggregateExpressions,
333+
partialMergeAggregateAttributes,
334+
namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes,
335+
partialAggregate)
336+
337+
// 3. Create an Aggregate Operator for final and complete aggregations.
338+
val finalAggregateExpressions = functionsWithoutDistinct.map {
339+
// TODO: Need to replace the children with distinctColumnAttributes.
340+
case AggregateExpression2(aggregateFunction, mode, _) =>
341+
AggregateExpression2(aggregateFunction, Final, false)
342+
}
343+
val finalAggregateAttributes =
344+
finalAggregateExpressions.map {
345+
expr => aggregateFunctionMap(expr.aggregateFunction)
346+
}
347+
val completeAggregateExpressions = functionsWithDistinct.map {
348+
case AggregateExpression2(aggregateFunction, mode, _) =>
349+
AggregateExpression2(aggregateFunction, Complete, false)
350+
}
351+
val completeAggregateAttributes =
352+
completeAggregateExpressions.map {
353+
expr => aggregateFunctionMap(expr.aggregateFunction)
354+
}
355+
356+
val rewrittenResultExpressions = resultExpressions.map { expr =>
357+
expr.transform {
358+
case agg: AggregateExpression2 =>
359+
aggregateFunctionMap(agg.aggregateFunction).toAttribute
360+
case expression if groupExpressionMap.contains(expression) =>
361+
groupExpressionMap(expression).toAttribute
362+
case expression if distinctColumnExpressionMap.contains(expression) =>
363+
distinctColumnExpressionMap(expression).toAttribute
364+
}.asInstanceOf[NamedExpression]
365+
}
366+
val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
367+
namedGroupingAttributes,
368+
finalAggregateExpressions,
369+
finalAggregateAttributes,
370+
completeAggregateExpressions,
371+
completeAggregateAttributes,
372+
rewrittenResultExpressions,
373+
partialMergeAggregate)
374+
375+
finalAndCompleteAggregate :: Nil
376+
}
377+
203378
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
204379
case logical.Aggregate(groupingExpressions, resultExpressions, child)
205380
if sqlContext.conf.useSqlAggregate2 =>
@@ -216,58 +391,33 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
216391
aggregateFunction -> Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
217392
}.toMap
218393

219-
// 2. Create an Aggregate Operator for partial aggregations.
220-
val namedGroupingExpressions = groupingExpressions.map {
221-
case ne: NamedExpression => ne -> ne
222-
// If the expression is not a NamedExpressions, we add an alias.
223-
// So, when we generate the result of the operator, the Aggregate Operator
224-
// can directly get the Seq of attributes representing the grouping expressions.
225-
case other =>
226-
val withAlias = Alias(other, other.toString)()
227-
other -> withAlias
228-
}
229-
val groupExpressionMap = namedGroupingExpressions.toMap
230-
val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
231-
val partialAggregateExpressions = aggregateExpressions.map {
232-
case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
233-
AggregateExpression2(aggregateFunction, Partial, isDistinct)
394+
val (functionsWithDistinct, functionsWithoutDistinct) =
395+
aggregateExpressions.partition(_.isDistinct)
396+
println("functionsWithDistinct " + functionsWithDistinct)
397+
if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
398+
// This is a sanity check. We should not reach here since we check the same thing in
399+
// CheckAggregateFunction.
400+
sys.error("Having more than one distinct column sets is not allowed.")
234401
}
235-
val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
236-
agg.aggregateFunction.bufferAttributes
237-
}
238-
val partialAggregate =
239-
Aggregate2Sort(
240-
namedGroupingExpressions.map(_._2),
241-
partialAggregateExpressions,
242-
partialAggregateAttributes,
243-
namedGroupingAttributes ++ partialAggregateAttributes,
244-
planLater(child))
245-
246-
// 3. Create an Aggregate Operator for final aggregations.
247-
val finalAggregateExpressions = aggregateExpressions.map {
248-
case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
249-
AggregateExpression2(aggregateFunction, Final, isDistinct)
250-
}
251-
val finalAggregateAttributes =
252-
finalAggregateExpressions.map {
253-
expr => aggregateFunctionMap(expr.aggregateFunction)
402+
val aggregate =
403+
if (functionsWithDistinct.isEmpty) {
404+
planAggregateWithoutDistinct(
405+
groupingExpressions,
406+
aggregateExpressions,
407+
aggregateFunctionMap,
408+
resultExpressions,
409+
planLater(child))
410+
} else {
411+
planAggregateWithOneDistinct(
412+
groupingExpressions,
413+
functionsWithDistinct,
414+
functionsWithoutDistinct,
415+
aggregateFunctionMap,
416+
resultExpressions,
417+
planLater(child))
254418
}
255-
val rewrittenResultExpressions = resultExpressions.map { expr =>
256-
expr.transform {
257-
case agg: AggregateExpression2 =>
258-
aggregateFunctionMap(agg.aggregateFunction).toAttribute
259-
case expression if groupExpressionMap.contains(expression) =>
260-
groupExpressionMap(expression).toAttribute
261-
}.asInstanceOf[NamedExpression]
262-
}
263-
val finalAggregate = Aggregate2Sort(
264-
namedGroupingAttributes,
265-
finalAggregateExpressions,
266-
finalAggregateAttributes,
267-
rewrittenResultExpressions,
268-
partialAggregate)
269419

270-
finalAggregate :: Nil
420+
aggregate
271421
case _ => Nil
272422
}
273423
}

0 commit comments

Comments
 (0)