
Commit c03299a

yhuai authored and rxin committed
[SPARK-4233] [SPARK-4367] [SPARK-3947] [SPARK-3056] [SQL] Aggregation Improvement
This is the first PR for the aggregation improvement, which is tracked by https://issues.apache.org/jira/browse/SPARK-4366 (umbrella JIRA). This PR contains work for its subtasks SPARK-3056, SPARK-3947, SPARK-4233, and SPARK-4367.

This PR introduces a new code path for evaluating aggregate functions. This code path is guarded by `spark.sql.useAggregate2`, and by default the value of this flag is true. The new code path contains:
* A new aggregate function interface (`AggregateFunction2`) and 7 built-in aggregate functions based on this new interface (`AVG`, `COUNT`, `FIRST`, `LAST`, `MAX`, `MIN`, `SUM`).
* A UDAF interface (`UserDefinedAggregateFunction`) based on the new code path and two example UDAFs (`MyDoubleAvg` and `MyDoubleSum`).
* A sort-based aggregate operator (`Aggregate2Sort`) for the new aggregate function interface.
* A sort-based aggregate operator (`FinalAndCompleteAggregate2Sort`) for distinct aggregations (for distinct aggregations the query plan will use `Aggregate2Sort` and `FinalAndCompleteAggregate2Sort` together).

With this change, when `spark.sql.useAggregate2` is `true`, the flow of compiling an aggregation query is:
1. Our analyzer looks up functions and returns aggregate functions built based on the old aggregate function interface.
2. When our planner is compiling the physical plan, it tries to convert all aggregate functions to the ones built based on the new interface. The planner will fall back to the old code path if any of the following conditions is true:
  * Code-gen is disabled.
  * There is any function that cannot be converted (right now, Hive UDAFs).
  * The schema of grouping expressions contains any complex data type.
  * There are multiple distinct columns. Right now, the new code path handles a single distinct column in the query (you can have multiple aggregate functions using that distinct column).
For a query having an aggregate function with DISTINCT and regular aggregate functions, the generated plan will do partial aggregations for those regular aggregate functions.

Thanks chenghao-intel for his initial work on it.

Author: Yin Huai <[email protected]>
Author: Michael Armbrust <[email protected]>

Closes apache#7458 from yhuai/UDAF and squashes the following commits:

7865f5e [Yin Huai] Put the catalyst expression in the comment of the generated code for it.
b04d6c8 [Yin Huai] Remove unnecessary change.
f1d5901 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
35b0520 [Yin Huai] Use semanticEquals to replace grouping expressions in the output of the aggregate operator.
3b43b24 [Yin Huai] bug fix.
00eb298 [Yin Huai] Make it compile.
a3ca551 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
e0afca3 [Yin Huai] Gracefully fallback to old aggregation code path.
8a8ac4a [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
88c7d4d [Yin Huai] Enable spark.sql.useAggregate2 by default for testing purpose.
dc96fd1 [Yin Huai] Many updates:
85c9c4b [Yin Huai] newline.
43de3de [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
c3614d7 [Yin Huai] Handle single distinct column.
68b8ee9 [Yin Huai] Support single distinct column set. WIP
3013579 [Yin Huai] Format.
d678aee [Yin Huai] Remove AggregateExpressionSuite.scala since our built-in aggregate functions will be based on AlgebraicAggregate and we need to have another way to test it.
e243ca6 [Yin Huai] Add aggregation iterators.
a101960 [Yin Huai] Change MyJavaUDAF to MyDoubleSum.
594cdf5 [Yin Huai] Change existing AggregateExpression to AggregateExpression1 and add an AggregateExpression as the common interface for both AggregateExpression1 and AggregateExpression2.
380880f [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
0a827b3 [Yin Huai] Add comments and doc. Move some classes to the right places.
a19fea6 [Yin Huai] Add UDAF interface.
262d4c4 [Yin Huai] Make it compile.
b2e358e [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
6edb5ac [Yin Huai] Format update.
70b169c [Yin Huai] Remove groupOrdering.
4721936 [Yin Huai] Add CheckAggregateFunction to extendedCheckRules.
d821a34 [Yin Huai] Cleanup.
32aea9c [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
5b46d41 [Yin Huai] Bug fix.
aff9534 [Yin Huai] Make Aggregate2Sort work with both algebraic AggregateFunctions and non-algebraic AggregateFunctions.
2857b55 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
4435f20 [Yin Huai] Add ConvertAggregateFunction to HiveContext's analyzer.
1b490ed [Michael Armbrust] make hive test
8cfa6a9 [Michael Armbrust] add test
1b0bb3f [Yin Huai] Do not bind references in AlgebraicAggregate and use code gen for all places.
072209f [Yin Huai] Bug fix: Handle expressions in grouping columns that are not attribute references.
f7d9e54 [Michael Armbrust] Merge remote-tracking branch 'apache/master' into UDAF
39ee975 [Yin Huai] Code cleanup: Remove unnecesary AttributeReferences.
b7720ba [Yin Huai] Add an analysis rule to convert aggregate function to the new version.
5c00f3f [Michael Armbrust] First draft of codegen
6bbc6ba [Michael Armbrust] now with correct answers!
f7996d0 [Michael Armbrust] Add AlgebraicAggregate
dded1c5 [Yin Huai] wip
1 parent f4785f5 commit c03299a
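
The fallback rule in the commit message can be read as a single predicate over the query. The following is a minimal sketch of that reading in plain Scala; `AggregationQuery`, its fields, and `FallbackSketch` are hypothetical names introduced only for illustration and are not part of this commit or of Spark.

```scala
// Hypothetical model of the fallback decision described in the commit message.
// None of these names come from Spark; they only mirror the four listed conditions.
final case class AggregationQuery(
    codegenEnabled: Boolean,          // false means "code-gen is disabled"
    allFunctionsConvertible: Boolean, // false if some function cannot be converted (e.g. a Hive UDAF)
    groupingHasComplexType: Boolean,  // true if a grouping expression has a complex data type
    distinctColumnSets: Int)          // how many different DISTINCT column sets the query uses

object FallbackSketch {
  /** Returns true when, under this reading, the new code path can be used. */
  def canUseNewPath(q: AggregationQuery): Boolean =
    q.codegenEnabled &&
      q.allFunctionsConvertible &&
      !q.groupingHasComplexType &&
      q.distinctColumnSets <= 1
}
```

Any single failing condition is enough to send the whole query back to the old code path, which is why the sketch combines the checks with `&&`.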

39 files changed: 3087 additions, 100 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 2 additions & 1 deletion
@@ -266,11 +266,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
       }
     }
     | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
-      { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
+      { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) }
     | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs =>
       lexical.normalizeKeyword(udfName) match {
         case "sum" => SumDistinct(exprs.head)
         case "count" => CountDistinct(exprs)
+        case name => UnresolvedFunction(name, exprs, isDistinct = true)
         case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT")
       }
     }
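
To make the effect of this parser change easier to follow, here is a small stand-alone model of the dispatch, written without any Spark dependency. The `Expr` hierarchy and `ParserSketch` are simplified stand-ins invented for this sketch, and `toLowerCase` is only an assumed approximation of `lexical.normalizeKeyword`. Because the added `case name` pattern matches every remaining function name, the sketch uses a single catch-all branch.

```scala
// Simplified stand-ins for the Catalyst expression classes touched by the diff.
sealed trait Expr
final case class Column(name: String) extends Expr
final case class SumDistinct(child: Expr) extends Expr
final case class CountDistinct(children: Seq[Expr]) extends Expr
final case class UnresolvedFunction(name: String, children: Seq[Expr], isDistinct: Boolean)
  extends Expr

object ParserSketch {
  /** Maps a parsed call such as `f(a)` or `f(DISTINCT a)` to an expression node. */
  def functionCall(name: String, args: Seq[Expr], distinct: Boolean): Expr =
    if (!distinct) {
      UnresolvedFunction(name, args, isDistinct = false)
    } else {
      // Assumption: lower-casing approximates the parser's keyword normalization.
      name.toLowerCase match {
        case "sum"   => SumDistinct(args.head)
        case "count" => CountDistinct(args)
        // Every other function keeps the DISTINCT flag for the analyzer to judge.
        case other   => UnresolvedFunction(other, args, isDistinct = true)
      }
    }
}
```

The point of the design is that the parser no longer rejects `f(DISTINCT ...)` for unknown `f`; it defers the decision to analysis by carrying `isDistinct` on the unresolved node.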

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 21 additions & 3 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.analysis

 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
 import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -277,7 +278,7 @@ class Analyzer(
         Project(
           projectList.flatMap {
             case s: Star => s.expand(child.output, resolver)
-            case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
+            case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) =>
               val expandedArgs = args.flatMap {
                 case s: Star => s.expand(child.output, resolver)
                 case o => o :: Nil
@@ -517,9 +518,26 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case q: LogicalPlan =>
         q transformExpressions {
-          case u @ UnresolvedFunction(name, children) =>
+          case u @ UnresolvedFunction(name, children, isDistinct) =>
             withPosition(u) {
-              registry.lookupFunction(name, children)
+              registry.lookupFunction(name, children) match {
+                // We get an aggregate function built based on AggregateFunction2 interface.
+                // So, we wrap it in AggregateExpression2.
+                case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
+                // Currently, our old aggregate function interface supports SUM(DISTINCT ...)
+                // and COUNT(DISTINCT ...).
+                case sumDistinct: SumDistinct => sumDistinct
+                case countDistinct: CountDistinct => countDistinct
+                // DISTINCT is not meaningful with Max and Min.
+                case max: Max if isDistinct => max
+                case min: Min if isDistinct => min
+                // For other aggregate functions, DISTINCT keyword is not supported for now.
+                // Once we converted to the new code path, we will allow using DISTINCT keyword.
+                case other if isDistinct =>
+                  failAnalysis(s"$name does not support DISTINCT keyword.")
+                // If it does not have DISTINCT keyword, we will return it as is.
+                case other => other
+              }
             }
         }
     }
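
The resolution rule added to the analyzer is a pattern match whose branch order matters. The sketch below models that match in plain Scala so the ordering can be checked in isolation; `Resolved`, `NewAgg`, `WrappedAgg`, the `Old*` objects, and `ResolveSketch` are hypothetical stand-ins for the Catalyst classes, and analysis failure is modeled as `Left` rather than a thrown `AnalysisException`.

```scala
// Hypothetical stand-ins for what registry.lookupFunction may return.
sealed trait Resolved
final case class NewAgg(name: String) extends Resolved   // models an AggregateFunction2
final case class WrappedAgg(fn: NewAgg, isDistinct: Boolean)
  extends Resolved                                        // models AggregateExpression2(fn, Complete, isDistinct)
case object OldSumDistinct extends Resolved
case object OldCountDistinct extends Resolved
case object OldMax extends Resolved
case object OldMin extends Resolved
final case class OldOther(name: String) extends Resolved // any other old-interface function

object ResolveSketch {
  /** Mirrors the branch order of the analyzer rule; Left carries the error message. */
  def resolve(name: String, looked: Resolved, isDistinct: Boolean): Either[String, Resolved] =
    looked match {
      // New-interface functions are wrapped and keep the DISTINCT flag.
      case agg2: NewAgg                      => Right(WrappedAgg(agg2, isDistinct))
      // The old interface already has dedicated DISTINCT variants for SUM and COUNT.
      case OldSumDistinct | OldCountDistinct => Right(looked)
      // DISTINCT does not change the result of MAX or MIN, so the flag is dropped.
      case OldMax | OldMin if isDistinct     => Right(looked)
      // Any other old-interface function cannot honor DISTINCT.
      case _ if isDistinct                   => Left(s"$name does not support DISTINCT keyword.")
      // Without DISTINCT the looked-up function is returned unchanged.
      case other                             => Right(other)
    }
}
```

Placing the `isDistinct` failure branch after the MAX/MIN branches is what lets `MAX(DISTINCT x)` pass analysis while, say, `AVG(DISTINCT x)` on the old interface is rejected.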

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis

 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 4 additions & 1 deletion
@@ -73,7 +73,10 @@ object UnresolvedAttribute {
   def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name))
 }

-case class UnresolvedFunction(name: String, children: Seq[Expression])
+case class UnresolvedFunction(
+    name: String,
+    children: Seq[Expression],
+    isDistinct: Boolean)
   extends Expression with Unevaluable {

   override def dataType: DataType = throw new UnresolvedException(this, "dataType")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
 case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
   extends LeafExpression with NamedExpression {

-  override def toString: String = s"input[$ordinal]"
+  override def toString: String = s"input[$ordinal, $dataType]"

   override def eval(input: InternalRow): Any = input(ordinal)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 2 additions & 1 deletion
@@ -96,7 +96,8 @@ abstract class Expression extends TreeNode[Expression] {
     val primitive = ctx.freshName("primitive")
     val ve = GeneratedExpressionCode("", isNull, primitive)
     ve.code = genCode(ctx, ve)
-    ve
+    // Add `this` in the comment.
+    ve.copy(s"/* $this */\n" + ve.code)
   }

   /**
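
The change above prefixes each generated snippet with a comment naming the expression it came from. A minimal sketch of the same idea follows, assuming a simplified `GeneratedCode` case class as a stand-in for `GeneratedExpressionCode`; both `GeneratedCode` and `CodegenCommentSketch` are hypothetical names used only here.

```scala
// Simplified stand-in for the code-gen result: source text plus two variable names.
final case class GeneratedCode(code: String, isNull: String, primitive: String)

object CodegenCommentSketch {
  /** Returns a copy whose code starts with a comment describing the source expression. */
  def withComment(exprDescription: String, generated: GeneratedCode): GeneratedCode =
    generated.copy(code = s"/* $exprDescription */\n" + generated.code)
}
```

Using `copy` leaves the original value untouched, and only the `code` field differs in the result.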
Lines changed: 292 additions & 0 deletions
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+case class Average(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = resultType
+
+  // Expected input data type.
+  // TODO: Once we remove the old code path, we can use our analyzer to cast NullType
+  // to the default data type of the NumericType.
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+
+  private val resultType = child.dataType match {
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType(precision + 4, scale + 4)
+    case DecimalType.Unlimited => DecimalType.Unlimited
+    case _ => DoubleType
+  }
+
+  private val sumDataType = child.dataType match {
+    case _ @ DecimalType() => DecimalType.Unlimited
+    case _ => DoubleType
+  }
+
+  private val currentSum = AttributeReference("currentSum", sumDataType)()
+  private val currentCount = AttributeReference("currentCount", LongType)()
+
+  override val bufferAttributes = currentSum :: currentCount :: Nil
+
+  override val initialValues = Seq(
+    /* currentSum = */ Cast(Literal(0), sumDataType),
+    /* currentCount = */ Literal(0L)
+  )
+
+  override val updateExpressions = Seq(
+    /* currentSum = */
+    Add(
+      currentSum,
+      Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
+    /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+  )
+
+  override val mergeExpressions = Seq(
+    /* currentSum = */ currentSum.left + currentSum.right,
+    /* currentCount = */ currentCount.left + currentCount.right
+  )
+
+  // If all input are nulls, currentCount will be 0 and we will get null after the division.
+  override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType)
+}
+
+case class Count(child: Expression) extends AlgebraicAggregate {
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = false
+
+  // Return data type.
+  override def dataType: DataType = LongType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val currentCount = AttributeReference("currentCount", LongType)()
+
+  override val bufferAttributes = currentCount :: Nil
+
+  override val initialValues = Seq(
+    /* currentCount = */ Literal(0L)
+  )
+
+  override val updateExpressions = Seq(
+    /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+  )
+
+  override val mergeExpressions = Seq(
+    /* currentCount = */ currentCount.left + currentCount.right
+  )
+
+  override val evaluateExpression = Cast(currentCount, LongType)
+}
+
+case class First(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // First is not a deterministic function.
+  override def deterministic: Boolean = false
+
+  // Return data type.
+  override def dataType: DataType = child.dataType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val first = AttributeReference("first", child.dataType)()
+
+  override val bufferAttributes = first :: Nil
+
+  override val initialValues = Seq(
+    /* first = */ Literal.create(null, child.dataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* first = */ If(IsNull(first), child, first)
+  )
+
+  override val mergeExpressions = Seq(
+    /* first = */ If(IsNull(first.left), first.right, first.left)
+  )
+
+  override val evaluateExpression = first
+}
+
+case class Last(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Last is not a deterministic function.
+  override def deterministic: Boolean = false
+
+  // Return data type.
+  override def dataType: DataType = child.dataType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val last = AttributeReference("last", child.dataType)()
+
+  override val bufferAttributes = last :: Nil
+
+  override val initialValues = Seq(
+    /* last = */ Literal.create(null, child.dataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* last = */ If(IsNull(child), last, child)
+  )
+
+  override val mergeExpressions = Seq(
+    /* last = */ If(IsNull(last.right), last.left, last.right)
+  )
+
+  override val evaluateExpression = last
+}
+
+case class Max(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = child.dataType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val max = AttributeReference("max", child.dataType)()
+
+  override val bufferAttributes = max :: Nil
+
+  override val initialValues = Seq(
+    /* max = */ Literal.create(null, child.dataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child))))
+  )
+
+  override val mergeExpressions = {
+    val greatest = Greatest(Seq(max.left, max.right))
+    Seq(
+      /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest))
+    )
+  }
+
+  override val evaluateExpression = max
+}
+
+case class Min(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = child.dataType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val min = AttributeReference("min", child.dataType)()
+
+  override val bufferAttributes = min :: Nil
+
+  override val initialValues = Seq(
+    /* min = */ Literal.create(null, child.dataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child))))
+  )
+
+  override val mergeExpressions = {
+    val least = Least(Seq(min.left, min.right))
+    Seq(
+      /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least))
+    )
+  }
+
+  override val evaluateExpression = min
+}
+
+case class Sum(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = resultType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
+
+  private val resultType = child.dataType match {
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType(precision + 4, scale + 4)
+    case DecimalType.Unlimited => DecimalType.Unlimited
+    case _ => child.dataType
+  }
+
+  private val sumDataType = child.dataType match {
+    case _ @ DecimalType() => DecimalType.Unlimited
+    case _ => child.dataType
+  }
+
+  private val currentSum = AttributeReference("currentSum", sumDataType)()
+
+  private val zero = Cast(Literal(0), sumDataType)
+
+  override val bufferAttributes = currentSum :: Nil
+
+  override val initialValues = Seq(
+    /* currentSum = */ Literal.create(null, sumDataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* currentSum = */
+    Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum))
+  )
+
+  override val mergeExpressions = {
+    val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType))
+    Seq(
+      /* currentSum = */
+      Coalesce(Seq(add, currentSum.left))
+    )
+  }
+
+  override val evaluateExpression = Cast(currentSum, resultType)
+}
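
Each aggregate in the file above is described by four pieces: initial buffer values, an update step per input row, a merge step for two partial buffers, and a final evaluation. The sketch below restates `Average` and `Sum` in plain Scala to show how those pieces fit together and how nulls are treated. It is an illustration only: `AvgBuffer`, `AverageSketch`, and `SumSketch` are hypothetical names, `Option` stands in for a nullable value, and `Double`/`Long` stand in for the buffer types that the real code derives from the input type.

```scala
// Buffer for AVG: a running sum and a count of non-null inputs.
final case class AvgBuffer(sum: Double, count: Long)

object AverageSketch {
  // Mirrors initialValues: sum = 0, count = 0.
  val initial: AvgBuffer = AvgBuffer(0.0, 0L)

  // Mirrors updateExpressions: a null input adds 0 and leaves the count unchanged.
  def update(b: AvgBuffer, input: Option[Double]): AvgBuffer =
    AvgBuffer(b.sum + input.getOrElse(0.0), if (input.isEmpty) b.count else b.count + 1L)

  // Mirrors mergeExpressions: partial buffers are added field by field.
  def merge(left: AvgBuffer, right: AvgBuffer): AvgBuffer =
    AvgBuffer(left.sum + right.sum, left.count + right.count)

  // Mirrors evaluateExpression: a zero count yields null, modeled here as None.
  def evaluate(b: AvgBuffer): Option[Double] =
    if (b.count == 0L) None else Some(b.sum / b.count)

  // Folds one partition of rows into a partial buffer.
  def aggregate(partition: Seq[Option[Double]]): AvgBuffer =
    partition.foldLeft(initial)(update)
}

object SumSketch {
  // Mirrors initialValues: the buffer starts as null, so SUM over only nulls stays null.
  val initial: Option[Long] = None

  // Mirrors Coalesce(Add(Coalesce(currentSum, 0), child), currentSum):
  // a non-null input is added to the buffer (treating a null buffer as 0);
  // a null input leaves the buffer as it was.
  def update(buffer: Option[Long], input: Option[Long]): Option[Long] =
    input.map(v => buffer.getOrElse(0L) + v).orElse(buffer)

  // The merge expression has the same shape, with the right buffer as the "input".
  def merge(left: Option[Long], right: Option[Long]): Option[Long] =
    update(left, right)
}
```

For example, averaging the partitions `Seq(Some(1.0), None, Some(3.0))` and `Seq(Some(5.0))` produces partial buffers `(4.0, 2)` and `(5.0, 1)`, which merge to `(9.0, 3)` and evaluate to `Some(3.0)`. The difference in initial values is the notable design point: `Average` can start from zero because its count distinguishes "no input" from "sum is zero", while `Sum` has no count and so must start from null.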
