Skip to content

Commit 8733fbf

Browse files
viirya authored and Vinitha Gankidi committed
[SPARK-27653][SQL] Add max_by() and min_by() SQL aggregate functions
This PR goes to add `max_by()` and `min_by()` SQL aggregate functions. Quoting from the [Presto docs](https://prestodb.github.io/docs/current/functions/aggregate.html#max_by) > max_by(x, y) → [same as x] > Returns the value of x associated with the maximum value of y over all input values. `min_by()` works similarly. Added tests. Closes apache#24557 from viirya/SPARK-27653. Authored-by: Liang-Chi Hsieh <[email protected]> Signed-off-by: Wenchen Fan <[email protected]> (cherry picked from commit d169b0a)
1 parent 08eb192 commit 8733fbf

File tree

4 files changed

+244
-0
lines changed

4 files changed

+244
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,10 @@ object FunctionRegistry {
282282
expression[Last]("last"),
283283
expression[Last]("last_value"),
284284
expression[Max]("max"),
285+
expression[MaxBy]("max_by"),
285286
expression[Average]("mean"),
286287
expression[Min]("min"),
288+
expression[MinBy]("min_by"),
287289
expression[Percentile]("percentile"),
288290
expression[Skewness]("skewness"),
289291
expression[ApproximatePercentile]("percentile_approx"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ package object dsl {
166166
def maxDistinct(e: Expression): Expression = Max(e).toAggregateExpression(isDistinct = true)
167167
def upper(e: Expression): Expression = Upper(e)
168168
def lower(e: Expression): Expression = Lower(e)
169+
def greatest(args: Expression*): Expression = Greatest(args)
170+
def least(args: Expression*): Expression = Least(args)
169171
def sqrt(e: Expression): Expression = Sqrt(e)
170172
def abs(e: Expression): Expression = Abs(e)
171173
def star(names: String*): Expression = names match {
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.aggregate
19+
20+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
22+
import org.apache.spark.sql.catalyst.expressions._
23+
import org.apache.spark.sql.catalyst.util.TypeUtils
24+
import org.apache.spark.sql.types._
25+
26+
/**
27+
* The shared abstract superclass for `MaxBy` and `MinBy` SQL aggregate functions.
28+
*/
29+
abstract class MaxMinBy extends DeclarativeAggregate {
30+
31+
def valueExpr: Expression
32+
def orderingExpr: Expression
33+
34+
protected def funcName: String
35+
// The predicate compares two ordering values.
36+
protected def predicate(oldExpr: Expression, newExpr: Expression): Expression
37+
// The arithmetic expression returns greatest/least value of all parameters.
38+
// Used to pick up updated ordering value.
39+
protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression
40+
41+
override def children: Seq[Expression] = valueExpr :: orderingExpr :: Nil
42+
43+
override def nullable: Boolean = true
44+
45+
// Return data type.
46+
override def dataType: DataType = valueExpr.dataType
47+
48+
override def checkInputDataTypes(): TypeCheckResult =
49+
TypeUtils.checkForOrderingExpr(orderingExpr.dataType, s"function $funcName")
50+
51+
// The attributes used to keep extremum (max or min) and associated aggregated values.
52+
private lazy val extremumOrdering =
53+
AttributeReference("extremumOrdering", orderingExpr.dataType)()
54+
private lazy val valueWithExtremumOrdering =
55+
AttributeReference("valueWithExtremumOrdering", valueExpr.dataType)()
56+
57+
override lazy val aggBufferAttributes: Seq[AttributeReference] =
58+
valueWithExtremumOrdering :: extremumOrdering :: Nil
59+
60+
private lazy val nullValue = Literal.create(null, valueExpr.dataType)
61+
private lazy val nullOrdering = Literal.create(null, orderingExpr.dataType)
62+
63+
override lazy val initialValues: Seq[Literal] = Seq(
64+
/* valueWithExtremumOrdering = */ nullValue,
65+
/* extremumOrdering = */ nullOrdering
66+
)
67+
68+
override lazy val updateExpressions: Seq[Expression] = Seq(
69+
/* valueWithExtremumOrdering = */
70+
CaseWhen(
71+
(extremumOrdering.isNull && orderingExpr.isNull, nullValue) ::
72+
(extremumOrdering.isNull, valueExpr) ::
73+
(orderingExpr.isNull, valueWithExtremumOrdering) :: Nil,
74+
If(predicate(extremumOrdering, orderingExpr), valueWithExtremumOrdering, valueExpr)
75+
),
76+
/* extremumOrdering = */ orderingUpdater(extremumOrdering, orderingExpr)
77+
)
78+
79+
override lazy val mergeExpressions: Seq[Expression] = Seq(
80+
/* valueWithExtremumOrdering = */
81+
CaseWhen(
82+
(extremumOrdering.left.isNull && extremumOrdering.right.isNull, nullValue) ::
83+
(extremumOrdering.left.isNull, valueWithExtremumOrdering.right) ::
84+
(extremumOrdering.right.isNull, valueWithExtremumOrdering.left) :: Nil,
85+
If(predicate(extremumOrdering.left, extremumOrdering.right),
86+
valueWithExtremumOrdering.left, valueWithExtremumOrdering.right)
87+
),
88+
/* extremumOrdering = */ orderingUpdater(extremumOrdering.left, extremumOrdering.right)
89+
)
90+
91+
override lazy val evaluateExpression: AttributeReference = valueWithExtremumOrdering
92+
}
93+
94+
@ExpressionDescription(
95+
usage = "_FUNC_(x, y) - Returns the value of `x` associated with the maximum value of `y`.",
96+
examples = """
97+
Examples:
98+
> SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y);
99+
b
100+
""",
101+
since = "3.0")
102+
case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy {
103+
override protected def funcName: String = "max_by"
104+
105+
override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression =
106+
oldExpr > newExpr
107+
108+
override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression =
109+
greatest(oldExpr, newExpr)
110+
}
111+
112+
@ExpressionDescription(
113+
usage = "_FUNC_(x, y) - Returns the value of `x` associated with the minimum value of `y`.",
114+
examples = """
115+
Examples:
116+
> SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y);
117+
a
118+
""",
119+
since = "3.0")
120+
case class MinBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy {
121+
override protected def funcName: String = "min_by"
122+
123+
override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression =
124+
oldExpr < newExpr
125+
126+
override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression =
127+
least(oldExpr, newExpr)
128+
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,4 +686,116 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
686686
}
687687
}
688688
}
689+
690+
test("max_by") {
691+
val yearOfMaxEarnings =
692+
sql("SELECT course, max_by(year, earnings) FROM courseSales GROUP BY course")
693+
checkAnswer(yearOfMaxEarnings, Row("dotNET", 2013) :: Row("Java", 2013) :: Nil)
694+
695+
checkAnswer(
696+
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
697+
Row("b") :: Nil
698+
)
699+
700+
checkAnswer(
701+
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"),
702+
Row("c") :: Nil
703+
)
704+
705+
checkAnswer(
706+
sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"),
707+
Row("c") :: Nil
708+
)
709+
710+
checkAnswer(
711+
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"),
712+
Row("b") :: Nil
713+
)
714+
715+
checkAnswer(
716+
sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"),
717+
Row(null) :: Nil
718+
)
719+
720+
// structs as ordering value.
721+
checkAnswer(
722+
sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
723+
"(('c', (10, 60))) AS tab(x, y)"),
724+
Row("c") :: Nil
725+
)
726+
727+
checkAnswer(
728+
sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
729+
"(('c', null)) AS tab(x, y)"),
730+
Row("b") :: Nil
731+
)
732+
733+
withTempView("tempView") {
734+
val dfWithMap = Seq((0, "a"), (1, "b"), (2, "c"))
735+
.toDF("x", "y")
736+
.select($"x", map($"x", $"y").as("y"))
737+
.createOrReplaceTempView("tempView")
738+
val error = intercept[AnalysisException] {
739+
sql("SELECT max_by(x, y) FROM tempView").show
740+
}
741+
assert(
742+
error.message.contains("function max_by does not support ordering on type map<int,string>"))
743+
}
744+
}
745+
746+
test("min_by") {
747+
val yearOfMinEarnings =
748+
sql("SELECT course, min_by(year, earnings) FROM courseSales GROUP BY course")
749+
checkAnswer(yearOfMinEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil)
750+
751+
checkAnswer(
752+
sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
753+
Row("a") :: Nil
754+
)
755+
756+
checkAnswer(
757+
sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"),
758+
Row("a") :: Nil
759+
)
760+
761+
checkAnswer(
762+
sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"),
763+
Row("c") :: Nil
764+
)
765+
766+
checkAnswer(
767+
sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"),
768+
Row("a") :: Nil
769+
)
770+
771+
checkAnswer(
772+
sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"),
773+
Row(null) :: Nil
774+
)
775+
776+
// structs as ordering value.
777+
checkAnswer(
778+
sql("select min_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
779+
"(('c', (10, 60))) AS tab(x, y)"),
780+
Row("a") :: Nil
781+
)
782+
783+
checkAnswer(
784+
sql("select min_by(x, y) FROM VALUES (('a', null)), (('b', (10, 50))), " +
785+
"(('c', (10, 60))) AS tab(x, y)"),
786+
Row("b") :: Nil
787+
)
788+
789+
withTempView("tempView") {
790+
val dfWithMap = Seq((0, "a"), (1, "b"), (2, "c"))
791+
.toDF("x", "y")
792+
.select($"x", map($"x", $"y").as("y"))
793+
.createOrReplaceTempView("tempView")
794+
val error = intercept[AnalysisException] {
795+
sql("SELECT min_by(x, y) FROM tempView").show
796+
}
797+
assert(
798+
error.message.contains("function min_by does not support ordering on type map<int,string>"))
799+
}
800+
}
689801
}

0 commit comments

Comments (0)