Skip to content

Commit dc96fd1

Browse files
committed
Many updates:
1. Implementing COUNT, FIRST, LAST, MAX, MIN, and SUM based on the new interface. 2. Automatically fall back to old aggregation code path if we cannot evaluate the query using the new code path. 3. Refactoring.
1 parent 85c9c4b commit dc96fd1

File tree

18 files changed

+812
-387
lines changed

18 files changed

+812
-387
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.analysis
1919

2020
import org.apache.spark.sql.AnalysisException
21-
import org.apache.spark.sql.catalyst.expressions.aggregate2.{DistinctAggregateExpression1, Complete, AggregateExpression2, AggregateFunction2}
21+
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
2222
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.plans.logical._
@@ -521,9 +521,21 @@ class Analyzer(
521521
case u @ UnresolvedFunction(name, children, isDistinct) =>
522522
withPosition(u) {
523523
registry.lookupFunction(name, children) match {
524+
// We get an aggregate function built based on AggregateFunction2 interface.
525+
// So, we wrap it in AggregateExpression2.
524526
case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
525-
case agg1: AggregateExpression1 if isDistinct =>
526-
DistinctAggregateExpression1(agg1)
527+
// Currently, our old aggregate function interface supports SUM(DISTINCT ...)
528+
// and COUNT(DISTINCT ...).
529+
case sumDistinct: SumDistinct => sumDistinct
530+
case countDistinct: CountDistinct => countDistinct
531+
// DISTINCT is not meaningful with Max and Min.
532+
case max: Max if isDistinct => max
533+
case min: Min if isDistinct => min
534+
// For other aggregate functions, DISTINCT keyword is not supported for now.
535+
// Once we have converted them to the new code path, we will allow using the DISTINCT keyword.
536+
case other if isDistinct =>
537+
failAnalysis(s"$name does not support DISTINCT keyword.")
538+
// If it does not have DISTINCT keyword, we will return it as is.
527539
case other => other
528540
}
529541
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
1919

2020
import org.apache.spark.sql.AnalysisException
2121
import org.apache.spark.sql.catalyst.expressions._
22-
import org.apache.spark.sql.catalyst.expressions.aggregate2.AggregateExpression2
22+
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
2323
import org.apache.spark.sql.catalyst.plans.logical._
2424
import org.apache.spark.sql.types._
2525

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.aggregate
19+
20+
import org.apache.spark.sql.catalyst.dsl.expressions._
21+
import org.apache.spark.sql.catalyst.expressions._
22+
import org.apache.spark.sql.types._
23+
24+
/**
 * Algebraic aggregate that computes the average of `child`.
 *
 * The aggregation buffer has two slots: a running sum (named "currentSum") and a count of the
 * non-null inputs seen so far (named "currentCount"). The final value is the sum divided by the
 * count, with both operands first cast to the result type.
 */
case class Average(child: Expression) extends AlgebraicAggregate {

  override def children: Seq[Expression] = Seq(child)

  override def nullable: Boolean = true

  // Type of the value produced by this aggregate.
  override def dataType: DataType = resultType

  // Accepted input types.
  // TODO: Once we remove the old code path, we can use our analyzer to cast NullType
  // to the default data type of the NumericType.
  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))

  // Fixed-precision decimals are widened by 4 digits of precision and 4 digits of scale,
  // unlimited decimals stay unlimited, and every other input type yields a double.
  private val resultType = child.dataType match {
    case DecimalType.Fixed(p, s) => DecimalType(p + 4, s + 4)
    case DecimalType.Unlimited => DecimalType.Unlimited
    case _ => DoubleType
  }

  // Type used for the running sum in the buffer: unlimited decimal for any decimal input,
  // double otherwise.
  private val sumDataType = child.dataType match {
    case DecimalType() => DecimalType.Unlimited
    case _ => DoubleType
  }

  private val runningSum = AttributeReference("currentSum", sumDataType)()
  private val runningCount = AttributeReference("currentCount", LongType)()

  override val bufferAttributes = List(runningSum, runningCount)

  // The sum starts at zero (cast to the sum type) and the count starts at 0.
  override val initialValues = {
    val zeroSum = Cast(Literal(0), sumDataType)
    val zeroCount = Literal(0L)
    Seq(zeroSum, zeroCount)
  }

  override val updateExpressions = {
    // A null input contributes zero to the sum ...
    val addend = Coalesce(Seq(Cast(child, sumDataType), Cast(Literal(0), sumDataType)))
    val nextSum = Add(runningSum, addend)
    // ... and leaves the count untouched.
    val nextCount = If(IsNull(child), runningCount, runningCount + 1L)
    Seq(nextSum, nextCount)
  }

  // Two partial buffers are combined slot by slot with a plain addition.
  override val mergeExpressions = {
    val mergedSum = runningSum.left + runningSum.right
    val mergedCount = runningCount.left + runningCount.right
    Seq(mergedSum, mergedCount)
  }

  // Original author's note: if all inputs are null, the count is 0 and the division
  // yields null.
  override val evaluateExpression = {
    val numerator = Cast(runningSum, resultType)
    val denominator = Cast(runningCount, resultType)
    numerator / denominator
  }
}
76+
77+
/**
 * Algebraic aggregate that counts the inputs for which `child` is not null.
 *
 * The buffer is a single long slot (named "currentCount") that starts at 0 and is
 * incremented by one for every non-null input.
 */
case class Count(child: Expression) extends AlgebraicAggregate {

  override def children: Seq[Expression] = Seq(child)

  override def nullable: Boolean = false

  // Type of the value produced by this aggregate.
  override def dataType: DataType = LongType

  // Any input type is accepted.
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

  private val runningCount = AttributeReference("currentCount", LongType)()

  override val bufferAttributes = List(runningCount)

  override val initialValues = {
    val zeroCount = Literal(0L)
    Seq(zeroCount)
  }

  // Null inputs leave the count unchanged; anything else adds one.
  override val updateExpressions = {
    val nextCount = If(IsNull(child), runningCount, runningCount + 1L)
    Seq(nextCount)
  }

  // Partial counts are combined by addition.
  override val mergeExpressions = {
    val mergedCount = runningCount.left + runningCount.right
    Seq(mergedCount)
  }

  override val evaluateExpression = Cast(runningCount, LongType)
}
106+
107+
/**
 * Algebraic aggregate that keeps the first non-null value of `child` it encounters.
 *
 * The buffer is a single slot (named "first") that starts as null. Once it holds a non-null
 * value, later inputs no longer change it.
 */
case class First(child: Expression) extends AlgebraicAggregate {

  override def children: Seq[Expression] = Seq(child)

  override def nullable: Boolean = true

  // First is not a deterministic function.
  override def deterministic: Boolean = false

  // The result has the same type as the input.
  override def dataType: DataType = child.dataType

  // Any input type is accepted.
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

  private val firstValue = AttributeReference("first", child.dataType)()

  override val bufferAttributes = List(firstValue)

  override val initialValues = {
    val nothingYet = Literal.create(null, child.dataType)
    Seq(nothingYet)
  }

  // Take the input only while the buffer is still null.
  override val updateExpressions = {
    val nextFirst = If(IsNull(firstValue), child, firstValue)
    Seq(nextFirst)
  }

  // Prefer the left buffer; fall back to the right one when the left is null.
  override val mergeExpressions = {
    val mergedFirst = If(IsNull(firstValue.left), firstValue.right, firstValue.left)
    Seq(mergedFirst)
  }

  override val evaluateExpression = firstValue
}
140+
141+
/**
 * Algebraic aggregate that keeps the most recent non-null value of `child`.
 *
 * The buffer is a single slot (named "last") that starts as null and is overwritten by
 * every non-null input.
 */
case class Last(child: Expression) extends AlgebraicAggregate {

  override def children: Seq[Expression] = Seq(child)

  override def nullable: Boolean = true

  // Last is not a deterministic function.
  override def deterministic: Boolean = false

  // The result has the same type as the input.
  override def dataType: DataType = child.dataType

  // Any input type is accepted.
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

  private val lastValue = AttributeReference("last", child.dataType)()

  override val bufferAttributes = List(lastValue)

  override val initialValues = {
    val nothingYet = Literal.create(null, child.dataType)
    Seq(nothingYet)
  }

  // A null input keeps the buffered value; any other input replaces it.
  override val updateExpressions = {
    val nextLast = If(IsNull(child), lastValue, child)
    Seq(nextLast)
  }

  // Prefer the right buffer; fall back to the left one when the right is null.
  override val mergeExpressions = {
    val mergedLast = If(IsNull(lastValue.right), lastValue.left, lastValue.right)
    Seq(mergedLast)
  }

  override val evaluateExpression = lastValue
}
174+
175+
/**
 * Algebraic aggregate that tracks the maximum of `child`.
 *
 * The buffer is a single slot (named "max") that starts as null. Null handling is done
 * explicitly with `If`/`IsNull`; `Greatest` is only applied when both operands are non-null.
 */
case class Max(child: Expression) extends AlgebraicAggregate {

  override def children: Seq[Expression] = Seq(child)

  override def nullable: Boolean = true

  // The result has the same type as the input.
  override def dataType: DataType = child.dataType

  // Any input type is accepted.
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

  private val currentMax = AttributeReference("max", child.dataType)()

  override val bufferAttributes = List(currentMax)

  override val initialValues = {
    val nothingYet = Literal.create(null, child.dataType)
    Seq(nothingYet)
  }

  override val updateExpressions = {
    // Both sides are non-null here, so compare them.
    val larger = Greatest(Seq(currentMax, child))
    // First non-null input seeds the buffer.
    val seedOrLarger = If(IsNull(currentMax), child, larger)
    // Null inputs never change the buffer.
    val nextMax = If(IsNull(child), currentMax, seedOrLarger)
    Seq(nextMax)
  }

  override val mergeExpressions = {
    val larger = Greatest(Seq(currentMax.left, currentMax.right))
    // If either side is null the other side wins; otherwise take the greater one.
    val nonNullSide = If(IsNull(currentMax.left), currentMax.right, larger)
    val mergedMax = If(IsNull(currentMax.right), currentMax.left, nonNullSide)
    Seq(mergedMax)
  }

  override val evaluateExpression = currentMax
}
208+
209+
/**
 * Algebraic aggregate that tracks the minimum of `child`.
 *
 * The buffer is a single slot (named "min") that starts as null. Null handling is done
 * explicitly with `If`/`IsNull`; `Least` is only applied when both operands are non-null.
 */
case class Min(child: Expression) extends AlgebraicAggregate {

  override def children: Seq[Expression] = Seq(child)

  override def nullable: Boolean = true

  // The result has the same type as the input.
  override def dataType: DataType = child.dataType

  // Any input type is accepted.
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

  private val currentMin = AttributeReference("min", child.dataType)()

  override val bufferAttributes = List(currentMin)

  override val initialValues = {
    val nothingYet = Literal.create(null, child.dataType)
    Seq(nothingYet)
  }

  override val updateExpressions = {
    // Both sides are non-null here, so compare them.
    val smaller = Least(Seq(currentMin, child))
    // First non-null input seeds the buffer.
    val seedOrSmaller = If(IsNull(currentMin), child, smaller)
    // Null inputs never change the buffer.
    val nextMin = If(IsNull(child), currentMin, seedOrSmaller)
    Seq(nextMin)
  }

  override val mergeExpressions = {
    val smaller = Least(Seq(currentMin.left, currentMin.right))
    // If either side is null the other side wins; otherwise take the lesser one.
    val nonNullSide = If(IsNull(currentMin.left), currentMin.right, smaller)
    val mergedMin = If(IsNull(currentMin.right), currentMin.left, nonNullSide)
    Seq(mergedMin)
  }

  override val evaluateExpression = currentMin
}
242+
243+
/**
 * Algebraic aggregate that computes the sum of `child`.
 *
 * The buffer is a single slot (named "currentSum") that starts as null, so the result stays
 * null until at least one non-null input has been added.
 *
 * NOTE(review): the result/sum type rules below are identical to the ones in `Average`
 * (fixed decimals widened by +4 precision and +4 scale, every non-decimal input mapped to
 * DoubleType). Confirm this is the intended typing for SUM and not a copy-paste leftover.
 */
case class Sum(child: Expression) extends AlgebraicAggregate {

  override def children: Seq[Expression] = Seq(child)

  override def nullable: Boolean = true

  // Type of the value produced by this aggregate.
  override def dataType: DataType = resultType

  // Accepted input types.
  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))

  private val resultType = child.dataType match {
    case DecimalType.Fixed(p, s) => DecimalType(p + 4, s + 4)
    case DecimalType.Unlimited => DecimalType.Unlimited
    case _ => DoubleType
  }

  // Type used for the running sum in the buffer: unlimited decimal for any decimal input,
  // double otherwise.
  private val sumDataType = child.dataType match {
    case DecimalType() => DecimalType.Unlimited
    case _ => DoubleType
  }

  private val runningSum = AttributeReference("currentSum", sumDataType)()

  // Zero in the sum type; substituted for a null buffer before adding.
  private val zeroSum = Cast(Literal(0), sumDataType)

  override val bufferAttributes = List(runningSum)

  override val initialValues = {
    val nothingYet = Literal.create(null, sumDataType)
    Seq(nothingYet)
  }

  override val updateExpressions = {
    // Buffer (or zero when the buffer is still null) plus the input.
    val added = Add(Coalesce(Seq(runningSum, zeroSum)), Cast(child, sumDataType))
    // The outer Coalesce falls back to the unchanged buffer when the addition is null.
    val nextSum = Coalesce(Seq(added, runningSum))
    Seq(nextSum)
  }

  override val mergeExpressions = {
    // Left buffer (or zero when null) plus the right buffer.
    val added = Add(Coalesce(Seq(runningSum.left, zeroSum)), Cast(runningSum.right, sumDataType))
    // Fall back to the left buffer when the addition is null.
    val mergedSum = Coalesce(Seq(added, runningSum.left))
    Seq(mergedSum)
  }

  override val evaluateExpression = Cast(runningSum, resultType)
}

0 commit comments

Comments
 (0)