diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index d725ef5da06e..986662c95114 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -289,8 +289,10 @@ object FunctionRegistry {
     expression[Last]("last"),
     expression[Last]("last_value"),
     expression[Max]("max"),
+    expression[MaxBy]("max_by"),
     expression[Average]("mean"),
     expression[Min]("min"),
+    expression[MinBy]("min_by"),
     expression[Percentile]("percentile"),
     expression[Skewness]("skewness"),
     expression[ApproximatePercentile]("percentile_approx"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala
new file mode 100644
index 000000000000..c7fdb15130c4
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.types._
+
+/**
+ * The shared abstract superclass for `MaxBy` and `MinBy` SQL aggregate functions.
+ */
+abstract class MaxMinBy extends DeclarativeAggregate {
+
+  def valueExpr: Expression
+  def orderingExpr: Expression
+
+  protected def funcName: String
+  // The predicate that compares two ordering values.
+  protected def predicate(oldExpr: Expression, newExpr: Expression): Expression
+  // The arithmetic expression that returns the greatest/least of its two parameters.
+  // Used to pick up the updated ordering value.
+  protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression
+
+  override def children: Seq[Expression] = valueExpr :: orderingExpr :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = valueExpr.dataType
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForOrderingExpr(orderingExpr.dataType, s"function $funcName")
+
+  // The attributes used to keep the extremum (max or min) and the associated aggregated value.
+  private lazy val extremumOrdering =
+    AttributeReference("extremumOrdering", orderingExpr.dataType)()
+  private lazy val valueWithExtremumOrdering =
+    AttributeReference("valueWithExtremumOrdering", valueExpr.dataType)()
+
+  override lazy val aggBufferAttributes: Seq[AttributeReference] =
+    valueWithExtremumOrdering :: extremumOrdering :: Nil
+
+  private lazy val nullValue = Literal.create(null, valueExpr.dataType)
+  private lazy val nullOrdering = Literal.create(null, orderingExpr.dataType)
+
+  override lazy val initialValues: Seq[Literal] = Seq(
+    /* valueWithExtremumOrdering = */ nullValue,
+    /* extremumOrdering = */ nullOrdering
+  )
+
+  override lazy val updateExpressions: Seq[Expression] = Seq(
+    /* valueWithExtremumOrdering = */
+    CaseWhen(
+      (extremumOrdering.isNull && orderingExpr.isNull, nullValue) ::
+        (extremumOrdering.isNull, valueExpr) ::
+        (orderingExpr.isNull, valueWithExtremumOrdering) :: Nil,
+      If(predicate(extremumOrdering, orderingExpr), valueWithExtremumOrdering, valueExpr)
+    ),
+    /* extremumOrdering = */ orderingUpdater(extremumOrdering, orderingExpr)
+  )
+
+  override lazy val mergeExpressions: Seq[Expression] = Seq(
+    /* valueWithExtremumOrdering = */
+    CaseWhen(
+      (extremumOrdering.left.isNull && extremumOrdering.right.isNull, nullValue) ::
+        (extremumOrdering.left.isNull, valueWithExtremumOrdering.right) ::
+        (extremumOrdering.right.isNull, valueWithExtremumOrdering.left) :: Nil,
+      If(predicate(extremumOrdering.left, extremumOrdering.right),
+        valueWithExtremumOrdering.left, valueWithExtremumOrdering.right)
+    ),
+    /* extremumOrdering = */ orderingUpdater(extremumOrdering.left, extremumOrdering.right)
+  )
+
+  override lazy val evaluateExpression: AttributeReference = valueWithExtremumOrdering
+}
+
+@ExpressionDescription(
+  usage = "_FUNC_(x, y) - Returns the value of `x` associated with the maximum value of `y`.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y);
+       b
+  """,
+  since = "3.0.0")
+case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy {
+  override protected def funcName: String = "max_by"
+
+  override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression =
+    oldExpr > newExpr
+
+  override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression =
+    greatest(oldExpr, newExpr)
+}
+
+@ExpressionDescription(
+  usage = "_FUNC_(x, y) - Returns the value of `x` associated with the minimum value of `y`.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y);
+       a
+  """,
+  since = "3.0.0")
+case class MinBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy {
+  override protected def funcName: String = "min_by"
+
+  override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression =
+    oldExpr < newExpr
+
+  override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression =
+    least(oldExpr, newExpr)
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 97aaa1b584af..d89ecc22a7c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -782,4 +782,116 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     val countAndDistinct = df.select(count("*"), countDistinct("*"))
     checkAnswer(countAndDistinct, Row(100000, 100))
   }
+
+  test("max_by") {
+    val yearOfMaxEarnings =
+      sql("SELECT course, max_by(year, earnings) FROM courseSales GROUP BY course")
+    checkAnswer(yearOfMaxEarnings, Row("dotNET", 2013) :: Row("Java", 2013) :: Nil)
+
+    checkAnswer(
+      sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
+      Row("b") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"),
+      Row("c") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"),
+      Row("c") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"),
+      Row("b") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"),
+      Row(null) :: Nil
+    )
+
+    // Structs as ordering values.
+    checkAnswer(
+      sql("SELECT max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
+        "(('c', (10, 60))) AS tab(x, y)"),
+      Row("c") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
+        "(('c', null)) AS tab(x, y)"),
+      Row("b") :: Nil
+    )
+
+    withTempView("tempView") {
+      Seq((0, "a"), (1, "b"), (2, "c"))
+        .toDF("x", "y")
+        .select($"x", map($"x", $"y").as("y"))
+        .createOrReplaceTempView("tempView")
+      val error = intercept[AnalysisException] {
+        sql("SELECT max_by(x, y) FROM tempView").show
+      }
+      assert(
+        error.message.contains("function max_by does not support ordering on type map"))
+    }
+  }
+
+  test("min_by") {
+    val yearOfMinEarnings =
+      sql("SELECT course, min_by(year, earnings) FROM courseSales GROUP BY course")
+    checkAnswer(yearOfMinEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil)
+
+    checkAnswer(
+      sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
+      Row("a") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"),
+      Row("a") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"),
+      Row("c") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"),
+      Row("a") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"),
+      Row(null) :: Nil
+    )
+
+    // Structs as ordering values.
+    checkAnswer(
+      sql("SELECT min_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
+        "(('c', (10, 60))) AS tab(x, y)"),
+      Row("a") :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', (10, 50))), " +
+        "(('c', (10, 60))) AS tab(x, y)"),
+      Row("b") :: Nil
+    )
+
+    withTempView("tempView") {
+      Seq((0, "a"), (1, "b"), (2, "c"))
+        .toDF("x", "y")
+        .select($"x", map($"x", $"y").as("y"))
+        .createOrReplaceTempView("tempView")
+      val error = intercept[AnalysisException] {
+        sql("SELECT min_by(x, y) FROM tempView").show
+      }
+      assert(
+        error.message.contains("function min_by does not support ordering on type map"))
+    }
+  }
 }
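
Note on the buffer semantics encoded by `updateExpressions` and `mergeExpressions` above: a plain-Scala sketch of the same logic may help when reviewing the CaseWhen branches. Everything below (the `MaxByModel` object, its Option-based `Buffer`, the Int-valued ordering) is an illustrative assumption for this note, not code from the patch or a Spark API.

    // Models the two-slot aggregation buffer (valueWithExtremumOrdering, extremumOrdering)
    // used by max_by; None stands in for the NULL-initialized ordering slot.
    object MaxByModel {
      type Buffer = (String, Option[Int])

      val initial: Buffer = (null, None)

      // Mirrors updateExpressions: a row with a NULL ordering is ignored, an empty
      // buffer adopts the row, otherwise the side with the larger ordering wins.
      def update(buf: Buffer, value: String, ordering: Option[Int]): Buffer =
        (buf, ordering) match {
          case (_, None)                => buf               // row ordering is NULL
          case ((_, None), Some(_))     => (value, ordering) // buffer still empty
          case ((_, Some(bo)), Some(o)) => if (bo > o) buf else (value, ordering)
        }

      // Mirrors mergeExpressions: the same rule applied to two partial buffers.
      def merge(l: Buffer, r: Buffer): Buffer = (l._2, r._2) match {
        case (None, _)            => r
        case (_, None)            => l
        case (Some(lo), Some(ro)) => if (lo > ro) l else r
      }
    }

Folding the rows from one of the tests above, ('a', 10), ('b', NULL), ('c', 20), through `update` yields ("c", Some(20)), matching the expected `Row("c")`; `MinBy` is the same model with the comparisons flipped.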