-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-27653][SQL] Add max_by() and min_by() SQL aggregate functions #24557
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,128 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.expressions.aggregate | ||
|
|
||
| import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | ||
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.util.TypeUtils | ||
| import org.apache.spark.sql.types._ | ||
|
|
||
| /** | ||
| * The shared abstract superclass for `MaxBy` and `MinBy` SQL aggregate functions. | ||
| */ | ||
| abstract class MaxMinBy extends DeclarativeAggregate { | ||
|
|
||
| def valueExpr: Expression | ||
| def orderingExpr: Expression | ||
|
|
||
| protected def funcName: String | ||
| // The predicate compares two ordering values. | ||
| protected def predicate(oldExpr: Expression, newExpr: Expression): Expression | ||
| // The arithmetic expression returns greatest/least value of all parameters. | ||
| // Used to pick up updated ordering value. | ||
| protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression | ||
|
|
||
| override def children: Seq[Expression] = valueExpr :: orderingExpr :: Nil | ||
|
|
||
| override def nullable: Boolean = true | ||
|
|
||
| // Return data type. | ||
| override def dataType: DataType = valueExpr.dataType | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = | ||
| TypeUtils.checkForOrderingExpr(orderingExpr.dataType, s"function $funcName") | ||
|
|
||
| // The attributes used to keep extremum (max or min) and associated aggregated values. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah that's a good point. Shall we call it
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good for me. +1 |
||
| private lazy val extremumOrdering = | ||
| AttributeReference("extremumOrdering", orderingExpr.dataType)() | ||
| private lazy val valueWithExtremumOrdering = | ||
| AttributeReference("valueWithExtremumOrdering", valueExpr.dataType)() | ||
|
|
||
| override lazy val aggBufferAttributes: Seq[AttributeReference] = | ||
| valueWithExtremumOrdering :: extremumOrdering :: Nil | ||
|
|
||
| private lazy val nullValue = Literal.create(null, valueExpr.dataType) | ||
| private lazy val nullOrdering = Literal.create(null, orderingExpr.dataType) | ||
|
|
||
| override lazy val initialValues: Seq[Literal] = Seq( | ||
| /* valueWithExtremumOrdering = */ nullValue, | ||
| /* extremumOrdering = */ nullOrdering | ||
| ) | ||
|
|
||
| override lazy val updateExpressions: Seq[Expression] = Seq( | ||
| /* valueWithExtremumOrdering = */ | ||
| CaseWhen( | ||
| (extremumOrdering.isNull && orderingExpr.isNull, nullValue) :: | ||
| (extremumOrdering.isNull, valueExpr) :: | ||
| (orderingExpr.isNull, valueWithExtremumOrdering) :: Nil, | ||
| If(predicate(extremumOrdering, orderingExpr), valueWithExtremumOrdering, valueExpr) | ||
| ), | ||
| /* extremumOrdering = */ orderingUpdater(extremumOrdering, orderingExpr) | ||
| ) | ||
|
|
||
| override lazy val mergeExpressions: Seq[Expression] = Seq( | ||
| /* valueWithExtremumOrdering = */ | ||
| CaseWhen( | ||
| (extremumOrdering.left.isNull && extremumOrdering.right.isNull, nullValue) :: | ||
| (extremumOrdering.left.isNull, valueWithExtremumOrdering.right) :: | ||
| (extremumOrdering.right.isNull, valueWithExtremumOrdering.left) :: Nil, | ||
| If(predicate(extremumOrdering.left, extremumOrdering.right), | ||
| valueWithExtremumOrdering.left, valueWithExtremumOrdering.right) | ||
| ), | ||
| /* extremumOrdering = */ orderingUpdater(extremumOrdering.left, extremumOrdering.right) | ||
| ) | ||
|
|
||
| override lazy val evaluateExpression: AttributeReference = valueWithExtremumOrdering | ||
| } | ||
|
|
||
| @ExpressionDescription( | ||
| usage = "_FUNC_(x, y) - Returns the value of `x` associated with the maximum value of `y`.", | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y); | ||
| b | ||
| """, | ||
| since = "3.0") | ||
| case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy { | ||
| override protected def funcName: String = "max_by" | ||
|
|
||
| override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression = | ||
| oldExpr > newExpr | ||
|
|
||
| override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression = | ||
| greatest(oldExpr, newExpr) | ||
| } | ||
|
|
||
| @ExpressionDescription( | ||
| usage = "_FUNC_(x, y) - Returns the value of `x` associated with the minimum value of `y`.", | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y); | ||
| a | ||
| """, | ||
| since = "3.0") | ||
| case class MinBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy { | ||
| override protected def funcName: String = "min_by" | ||
|
|
||
| override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression = | ||
| oldExpr < newExpr | ||
|
|
||
| override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression = | ||
| least(oldExpr, newExpr) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -782,4 +782,116 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { | |
| val countAndDistinct = df.select(count("*"), countDistinct("*")) | ||
| checkAnswer(countAndDistinct, Row(100000, 100)) | ||
| } | ||
|
|
||
| test("max_by") { | ||
| val yearOfMaxEarnings = | ||
| sql("SELECT course, max_by(year, earnings) FROM courseSales GROUP BY course") | ||
| checkAnswer(yearOfMaxEarnings, Row("dotNET", 2013) :: Row("Java", 2013) :: Nil) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"), | ||
| Row("b") :: Nil | ||
| ) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"), | ||
| Row("c") :: Nil | ||
| ) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"), | ||
| Row("c") :: Nil | ||
| ) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"), | ||
| Row("b") :: Nil | ||
| ) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This returns also returns
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This makes sense if you think of this function as being semantically equivalent to |
||
| Row(null) :: Nil | ||
| ) | ||
|
|
||
| // structs as ordering value. | ||
| checkAnswer( | ||
| sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " + | ||
| "(('c', (10, 60))) AS tab(x, y)"), | ||
| Row("c") :: Nil | ||
| ) | ||
|
|
||
| checkAnswer( | ||
| sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " + | ||
| "(('c', null)) AS tab(x, y)"), | ||
| Row("b") :: Nil | ||
| ) | ||
|
|
||
| withTempView("tempView") { | ||
| val dfWithMap = Seq((0, "a"), (1, "b"), (2, "c")) | ||
| .toDF("x", "y") | ||
| .select($"x", map($"x", $"y").as("y")) | ||
| .createOrReplaceTempView("tempView") | ||
| val error = intercept[AnalysisException] { | ||
| sql("SELECT max_by(x, y) FROM tempView").show | ||
| } | ||
| assert( | ||
| error.message.contains("function max_by does not support ordering on type map<int,string>")) | ||
| } | ||
| } | ||
|
|
||
| test("min_by") { | ||
| val yearOfMinEarnings = | ||
| sql("SELECT course, min_by(year, earnings) FROM courseSales GROUP BY course") | ||
| checkAnswer(yearOfMinEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"), | ||
| Row("a") :: Nil | ||
| ) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"), | ||
| Row("a") :: Nil | ||
| ) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"), | ||
| Row("c") :: Nil | ||
| ) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"), | ||
| Row("a") :: Nil | ||
| ) | ||
|
|
||
| checkAnswer( | ||
| sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"), | ||
| Row(null) :: Nil | ||
| ) | ||
|
|
||
| // structs as ordering value. | ||
| checkAnswer( | ||
| sql("select min_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " + | ||
| "(('c', (10, 60))) AS tab(x, y)"), | ||
| Row("a") :: Nil | ||
| ) | ||
|
|
||
| checkAnswer( | ||
| sql("select min_by(x, y) FROM VALUES (('a', null)), (('b', (10, 50))), " + | ||
| "(('c', (10, 60))) AS tab(x, y)"), | ||
| Row("b") :: Nil | ||
| ) | ||
|
|
||
| withTempView("tempView") { | ||
| val dfWithMap = Seq((0, "a"), (1, "b"), (2, "c")) | ||
| .toDF("x", "y") | ||
| .select($"x", map($"x", $"y").as("y")) | ||
| .createOrReplaceTempView("tempView") | ||
| val error = intercept[AnalysisException] { | ||
| sql("SELECT min_by(x, y) FROM tempView").show | ||
| } | ||
| assert( | ||
| error.message.contains("function min_by does not support ordering on type map<int,string>")) | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since we import the expression DSL, can we use DSL to build the expression tree in this file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some can, like
And,IsNull. Some can't, likeCaseWhen,If.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I rewrite
AndandIsNullusing DSL.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we can add DSL for CaseWhen and If. Not a blocker here.