Skip to content

Commit d169b0a

Browse files
viiryacloud-fan
authored and committed
[SPARK-27653][SQL] Add max_by() and min_by() SQL aggregate functions
## What changes were proposed in this pull request? This PR adds the `max_by()` and `min_by()` SQL aggregate functions. Quoting from the [Presto docs](https://prestodb.github.io/docs/current/functions/aggregate.html#max_by) > max_by(x, y) → [same as x] > Returns the value of x associated with the maximum value of y over all input values. `min_by()` works similarly. ## How was this patch tested? Added tests. Closes apache#24557 from viirya/SPARK-27653. Authored-by: Liang-Chi Hsieh <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 126310c commit d169b0a

File tree

3 files changed

+242
-0
lines changed

3 files changed

+242
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,10 @@ object FunctionRegistry {
289289
expression[Last]("last"),
290290
expression[Last]("last_value"),
291291
expression[Max]("max"),
292+
expression[MaxBy]("max_by"),
292293
expression[Average]("mean"),
293294
expression[Min]("min"),
295+
expression[MinBy]("min_by"),
294296
expression[Percentile]("percentile"),
295297
expression[Skewness]("skewness"),
296298
expression[ApproximatePercentile]("percentile_approx"),
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.aggregate
19+
20+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
22+
import org.apache.spark.sql.catalyst.expressions._
23+
import org.apache.spark.sql.catalyst.util.TypeUtils
24+
import org.apache.spark.sql.types._
25+
26+
/**
 * The shared abstract superclass for `MaxBy` and `MinBy` SQL aggregate functions.
 *
 * Implemented as a [[DeclarativeAggregate]]: the aggregation buffer holds two slots —
 * the extremum (max or min) ordering value seen so far, and the value expression that
 * was associated with it. Subclasses only supply the comparison direction via
 * `predicate` and `orderingUpdater`.
 */
abstract class MaxMinBy extends DeclarativeAggregate {

  // The value to return (`x` in max_by(x, y)).
  def valueExpr: Expression
  // The value to order by (`y` in max_by(x, y)).
  def orderingExpr: Expression

  // SQL-facing function name, used in error messages (e.g. "max_by").
  protected def funcName: String
  // The predicate compares two ordering values. Returns true when `oldExpr`
  // should be kept over `newExpr` (i.e. old is the extremum).
  protected def predicate(oldExpr: Expression, newExpr: Expression): Expression
  // The arithmetic expression returns greatest/least value of all parameters.
  // Used to pick up updated ordering value.
  protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression

  override def children: Seq[Expression] = valueExpr :: orderingExpr :: Nil

  // Nullable: the result is null when no row had a non-null ordering value.
  override def nullable: Boolean = true

  // Return data type.
  override def dataType: DataType = valueExpr.dataType

  // Only the ordering expression needs to be orderable; the value expression
  // can be of any type since it is never compared.
  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForOrderingExpr(orderingExpr.dataType, s"function $funcName")

  // The attributes used to keep extremum (max or min) and associated aggregated values.
  private lazy val extremumOrdering =
    AttributeReference("extremumOrdering", orderingExpr.dataType)()
  private lazy val valueWithExtremumOrdering =
    AttributeReference("valueWithExtremumOrdering", valueExpr.dataType)()

  override lazy val aggBufferAttributes: Seq[AttributeReference] =
    valueWithExtremumOrdering :: extremumOrdering :: Nil

  // Typed null literals used to initialize the buffer and for null propagation.
  private lazy val nullValue = Literal.create(null, valueExpr.dataType)
  private lazy val nullOrdering = Literal.create(null, orderingExpr.dataType)

  override lazy val initialValues: Seq[Literal] = Seq(
    /* valueWithExtremumOrdering = */ nullValue,
    /* extremumOrdering = */ nullOrdering
  )

  // Per-input-row update. Rows whose ordering value is null are ignored
  // (the buffer is kept); a non-null ordering value always replaces a null buffer.
  override lazy val updateExpressions: Seq[Expression] = Seq(
    /* valueWithExtremumOrdering = */
    CaseWhen(
      (extremumOrdering.isNull && orderingExpr.isNull, nullValue) ::
        (extremumOrdering.isNull, valueExpr) ::
        (orderingExpr.isNull, valueWithExtremumOrdering) :: Nil,
      If(predicate(extremumOrdering, orderingExpr), valueWithExtremumOrdering, valueExpr)
    ),
    /* extremumOrdering = */ orderingUpdater(extremumOrdering, orderingExpr)
  )

  // Merge of two partial-aggregation buffers: same null ladder as the update path,
  // but comparing the `.left` and `.right` copies of the buffer attributes.
  override lazy val mergeExpressions: Seq[Expression] = Seq(
    /* valueWithExtremumOrdering = */
    CaseWhen(
      (extremumOrdering.left.isNull && extremumOrdering.right.isNull, nullValue) ::
        (extremumOrdering.left.isNull, valueWithExtremumOrdering.right) ::
        (extremumOrdering.right.isNull, valueWithExtremumOrdering.left) :: Nil,
      If(predicate(extremumOrdering.left, extremumOrdering.right),
        valueWithExtremumOrdering.left, valueWithExtremumOrdering.right)
    ),
    /* extremumOrdering = */ orderingUpdater(extremumOrdering.left, extremumOrdering.right)
  )

  // Final result: the value associated with the extremum ordering value.
  override lazy val evaluateExpression: AttributeReference = valueWithExtremumOrdering
}
93+
94+
@ExpressionDescription(
  usage = "_FUNC_(x, y) - Returns the value of `x` associated with the maximum value of `y`.",
  examples = """
    Examples:
      > SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y);
       b
  """,
  since = "3.0")
case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy {
  override protected def funcName: String = "max_by"

  // Keep the buffered value while its ordering value is still the larger one.
  // Built with the explicit Catalyst node; equivalent to the dsl form `oldExpr > newExpr`.
  override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression =
    GreaterThan(oldExpr, newExpr)

  // Track the running maximum of the ordering values.
  // Equivalent to the dsl helper `greatest(oldExpr, newExpr)`.
  override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression =
    Greatest(oldExpr :: newExpr :: Nil)
}
111+
112+
@ExpressionDescription(
  usage = "_FUNC_(x, y) - Returns the value of `x` associated with the minimum value of `y`.",
  examples = """
    Examples:
      > SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y);
       a
  """,
  since = "3.0")
case class MinBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy {
  override protected def funcName: String = "min_by"

  // Keep the buffered value while its ordering value is still the smaller one.
  // Built with the explicit Catalyst node; equivalent to the dsl form `oldExpr < newExpr`.
  override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression =
    LessThan(oldExpr, newExpr)

  // Track the running minimum of the ordering values.
  // Equivalent to the dsl helper `least(oldExpr, newExpr)`.
  override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression =
    Least(oldExpr :: newExpr :: Nil)
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,4 +782,116 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
782782
val countAndDistinct = df.select(count("*"), countDistinct("*"))
783783
checkAnswer(countAndDistinct, Row(100000, 100))
784784
}
785+
786+
test("max_by") {
787+
val yearOfMaxEarnings =
788+
sql("SELECT course, max_by(year, earnings) FROM courseSales GROUP BY course")
789+
checkAnswer(yearOfMaxEarnings, Row("dotNET", 2013) :: Row("Java", 2013) :: Nil)
790+
791+
checkAnswer(
792+
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
793+
Row("b") :: Nil
794+
)
795+
796+
checkAnswer(
797+
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"),
798+
Row("c") :: Nil
799+
)
800+
801+
checkAnswer(
802+
sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"),
803+
Row("c") :: Nil
804+
)
805+
806+
checkAnswer(
807+
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"),
808+
Row("b") :: Nil
809+
)
810+
811+
checkAnswer(
812+
sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"),
813+
Row(null) :: Nil
814+
)
815+
816+
// structs as ordering value.
817+
checkAnswer(
818+
sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
819+
"(('c', (10, 60))) AS tab(x, y)"),
820+
Row("c") :: Nil
821+
)
822+
823+
checkAnswer(
824+
sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
825+
"(('c', null)) AS tab(x, y)"),
826+
Row("b") :: Nil
827+
)
828+
829+
withTempView("tempView") {
830+
val dfWithMap = Seq((0, "a"), (1, "b"), (2, "c"))
831+
.toDF("x", "y")
832+
.select($"x", map($"x", $"y").as("y"))
833+
.createOrReplaceTempView("tempView")
834+
val error = intercept[AnalysisException] {
835+
sql("SELECT max_by(x, y) FROM tempView").show
836+
}
837+
assert(
838+
error.message.contains("function max_by does not support ordering on type map<int,string>"))
839+
}
840+
}
841+
842+
test("min_by") {
843+
val yearOfMinEarnings =
844+
sql("SELECT course, min_by(year, earnings) FROM courseSales GROUP BY course")
845+
checkAnswer(yearOfMinEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil)
846+
847+
checkAnswer(
848+
sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
849+
Row("a") :: Nil
850+
)
851+
852+
checkAnswer(
853+
sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"),
854+
Row("a") :: Nil
855+
)
856+
857+
checkAnswer(
858+
sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"),
859+
Row("c") :: Nil
860+
)
861+
862+
checkAnswer(
863+
sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"),
864+
Row("a") :: Nil
865+
)
866+
867+
checkAnswer(
868+
sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"),
869+
Row(null) :: Nil
870+
)
871+
872+
// structs as ordering value.
873+
checkAnswer(
874+
sql("select min_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
875+
"(('c', (10, 60))) AS tab(x, y)"),
876+
Row("a") :: Nil
877+
)
878+
879+
checkAnswer(
880+
sql("select min_by(x, y) FROM VALUES (('a', null)), (('b', (10, 50))), " +
881+
"(('c', (10, 60))) AS tab(x, y)"),
882+
Row("b") :: Nil
883+
)
884+
885+
withTempView("tempView") {
886+
val dfWithMap = Seq((0, "a"), (1, "b"), (2, "c"))
887+
.toDF("x", "y")
888+
.select($"x", map($"x", $"y").as("y"))
889+
.createOrReplaceTempView("tempView")
890+
val error = intercept[AnalysisException] {
891+
sql("SELECT min_by(x, y) FROM tempView").show
892+
}
893+
assert(
894+
error.message.contains("function min_by does not support ordering on type map<int,string>"))
895+
}
896+
}
785897
}

0 commit comments

Comments
 (0)