Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,10 @@ object FunctionRegistry {
expression[Last]("last"),
expression[Last]("last_value"),
expression[Max]("max"),
expression[MaxBy]("max_by"),
expression[Average]("mean"),
expression[Min]("min"),
expression[MinBy]("min_by"),
expression[Percentile]("percentile"),
expression[Skewness]("skewness"),
expression[ApproximatePercentile]("percentile_approx"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we import the expression DSL, can we use DSL to build the expression tree in this file?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some can, like And, IsNull. Some can't, like CaseWhen, If.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rewrote `And` and `IsNull` using the DSL.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can add DSL for CaseWhen and If. Not a blocker here.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

/**
* The shared abstract superclass for `MaxBy` and `MinBy` SQL aggregate functions.
*/
abstract class MaxMinBy extends DeclarativeAggregate {

// Expression whose data type is the aggregate's result type; its value is what gets buffered
// and finally returned.
def valueExpr: Expression
// Expression whose values are compared to decide which `valueExpr` to keep.
def orderingExpr: Expression

// Function name interpolated into the message passed to the ordering type check.
protected def funcName: String
// The predicate compares two ordering values. When it evaluates to true the value associated
// with `oldExpr` is kept; otherwise the value associated with `newExpr` replaces it.
protected def predicate(oldExpr: Expression, newExpr: Expression): Expression
// The arithmetic expression returns greatest/least value of all parameters.
// Used to pick up updated ordering value.
protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression

// Inputs in positional order: the value first, then the ordering.
override def children: Seq[Expression] = Seq(valueExpr, orderingExpr)

// The buffers start out as null literals, so the result is always declared nullable.
override def nullable: Boolean = true

// The aggregate yields a value, so its type is the value expression's type.
override def dataType: DataType = valueExpr.dataType

// Delegates the check of the ordering expression's type to TypeUtils, labelling any
// failure with this function's name.
override def checkInputDataTypes(): TypeCheckResult = {
  val orderingType = orderingExpr.dataType
  TypeUtils.checkForOrderingExpr(orderingType, s"function $funcName")
}

// The attributes used to keep extremum (max or min) and associated aggregated values.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah that's a good point. Shall we call it extremumOrdering then?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good for me. +1

// Buffer slot holding the extremum ordering value seen so far.
private lazy val extremumOrdering: AttributeReference =
  AttributeReference("extremumOrdering", orderingExpr.dataType)()
// Buffer slot holding the value that accompanied that ordering value.
private lazy val valueWithExtremumOrdering: AttributeReference =
  AttributeReference("valueWithExtremumOrdering", valueExpr.dataType)()

// Buffer layout: value first, ordering second. `initialValues` follows the same order.
override lazy val aggBufferAttributes: Seq[AttributeReference] =
  Seq(valueWithExtremumOrdering, extremumOrdering)

// Typed null literals used both as initial buffer contents and as the "nothing found" result.
private lazy val nullValue: Literal = Literal.create(null, valueExpr.dataType)
private lazy val nullOrdering: Literal = Literal.create(null, orderingExpr.dataType)

// Both buffer slots start out as null.
override lazy val initialValues: Seq[Literal] =
  nullValue :: // valueWithExtremumOrdering
    nullOrdering :: // extremumOrdering
    Nil

// Expressions that fold one input row into the buffer, in buffer order (value, ordering).
override lazy val updateExpressions: Seq[Expression] = {
  // Null handling comes first; the branch order matters because the first matching
  // condition wins.
  val nullCases: Seq[(Expression, Expression)] = Seq(
    // Nothing buffered and nothing usable in this row: stay null.
    (extremumOrdering.isNull && orderingExpr.isNull, nullValue),
    // Nothing buffered yet: take the incoming value.
    (extremumOrdering.isNull, valueExpr),
    // Incoming ordering is null: keep what is buffered.
    (orderingExpr.isNull, valueWithExtremumOrdering)
  )
  // Both orderings are non-null here: the subclass predicate decides whether to keep the
  // buffered value.
  val keepBuffered = predicate(extremumOrdering, orderingExpr)
  val nextValue =
    CaseWhen(nullCases, If(keepBuffered, valueWithExtremumOrdering, valueExpr))
  val nextOrdering = orderingUpdater(extremumOrdering, orderingExpr)

  Seq(nextValue, nextOrdering)
}

// Expressions that combine the two buffer sides (addressed through `.left` and `.right`),
// in buffer order (value, ordering).
override lazy val mergeExpressions: Seq[Expression] = {
  val leftOrdering = extremumOrdering.left
  val rightOrdering = extremumOrdering.right
  val leftValue = valueWithExtremumOrdering.left
  val rightValue = valueWithExtremumOrdering.right

  // Null handling comes first; the branch order matters because the first matching
  // condition wins.
  val nullCases: Seq[(Expression, Expression)] = Seq(
    // Neither side has an ordering value: the merged value is null.
    (leftOrdering.isNull && rightOrdering.isNull, nullValue),
    // Only the right side has an ordering value: take its value.
    (leftOrdering.isNull, rightValue),
    // Only the left side has an ordering value: take its value.
    (rightOrdering.isNull, leftValue)
  )
  // Both sides have an ordering value: the subclass predicate picks the side to keep.
  val keepLeft = predicate(leftOrdering, rightOrdering)
  val mergedValue = CaseWhen(nullCases, If(keepLeft, leftValue, rightValue))
  val mergedOrdering = orderingUpdater(leftOrdering, rightOrdering)

  Seq(mergedValue, mergedOrdering)
}

// The final result is read straight from the buffered value slot; no post-processing is applied.
override lazy val evaluateExpression: AttributeReference = valueWithExtremumOrdering
}

/**
 * The `max_by` SQL aggregate: returns the `valueExpr` paired with the maximum `orderingExpr`.
 */
@ExpressionDescription(
  usage = "_FUNC_(x, y) - Returns the value of `x` associated with the maximum value of `y`.",
  examples = """
Examples:
> SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y);
b
""",
  since = "3.0")
case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy {

  // True while the old ordering value is strictly greater than the new one.
  override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression = {
    oldExpr > newExpr
  }

  // The ordering value carried forward is the greater of the two.
  override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression = {
    greatest(oldExpr, newExpr)
  }

  override protected def funcName: String = "max_by"
}

/**
 * The `min_by` SQL aggregate: returns the `valueExpr` paired with the minimum `orderingExpr`.
 */
@ExpressionDescription(
  usage = "_FUNC_(x, y) - Returns the value of `x` associated with the minimum value of `y`.",
  examples = """
Examples:
> SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y);
a
""",
  since = "3.0")
case class MinBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy {

  // True while the old ordering value is strictly less than the new one.
  override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression = {
    oldExpr < newExpr
  }

  // The ordering value carried forward is the lesser of the two.
  override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression = {
    least(oldExpr, newExpr)
  }

  override protected def funcName: String = "min_by"
}
Original file line number Diff line number Diff line change
Expand Up @@ -782,4 +782,116 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
val countAndDistinct = df.select(count("*"), countDistinct("*"))
checkAnswer(countAndDistinct, Row(100000, 100))
}

test("max_by") {
// Grouped aggregation: for each course, the year paired with the highest earnings.
// NOTE(review): courseSales is a fixture defined outside this view.
val yearOfMaxEarnings =
sql("SELECT course, max_by(year, earnings) FROM courseSales GROUP BY course")
checkAnswer(yearOfMaxEarnings, Row("dotNET", 2013) :: Row("Java", 2013) :: Nil)

// No null ordering values: 'b' carries the largest y.
checkAnswer(
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
Row("b") :: Nil
)

// A null ordering value in the middle of the input is not selected.
checkAnswer(
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"),
Row("c") :: Nil
)

// Leading null ordering values: the only non-null ordering value wins.
checkAnswer(
sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"),
Row("c") :: Nil
)

// A trailing null ordering value does not displace the current maximum.
checkAnswer(
sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"),
Row("b") :: Nil
)

checkAnswer(
sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This returns null because all values of the ordering column are null? That seems to match Presto behavior:

SELECT max_by(x, y) FROM (
  VALUES
    ('a', null),
    ('b', null)
) AS tab (x, y)

also returns null in Presto 👍

Copy link
Contributor

@JoshRosen JoshRosen May 8, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes sense if you think of this function as being semantically equivalent to

SELECT first(x) FROM tab WHERE y = max(y)

Row(null) :: Nil
)

// structs as ordering value.
checkAnswer(
  sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
    "(('c', (10, 60))) AS tab(x, y)"),
  Row("c") :: Nil
)

// A null struct ordering value is not selected.
checkAnswer(
  sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
    "(('c', null)) AS tab(x, y)"),
  Row("b") :: Nil
)

// A map-typed ordering column must be rejected with an AnalysisException.
withTempView("tempView") {
  // createOrReplaceTempView returns Unit, so its result is not bound to a name.
  Seq((0, "a"), (1, "b"), (2, "c"))
    .toDF("x", "y")
    .select($"x", map($"x", $"y").as("y"))
    .createOrReplaceTempView("tempView")
  val error = intercept[AnalysisException] {
    sql("SELECT max_by(x, y) FROM tempView").show()
  }
  assert(
    error.message.contains("function max_by does not support ordering on type map<int,string>"))
}
}

// Exercises the `min_by` aggregate: plain values, null ordering values, struct ordering
// values, and rejection of a map-typed ordering column.
test("min_by") {
  // Grouped aggregation: for each course, the year paired with the lowest earnings.
  val yearOfMinEarnings =
    sql("SELECT course, min_by(year, earnings) FROM courseSales GROUP BY course")
  checkAnswer(yearOfMinEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil)

  // No null ordering values: 'a' carries the smallest y.
  checkAnswer(
    sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
    Row("a") :: Nil
  )

  // A null ordering value in the middle of the input is not selected.
  checkAnswer(
    sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"),
    Row("a") :: Nil
  )

  // Leading null ordering values: the only non-null ordering value wins.
  checkAnswer(
    sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"),
    Row("c") :: Nil
  )

  // A trailing null ordering value does not displace the current minimum.
  checkAnswer(
    sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"),
    Row("a") :: Nil
  )

  // Every ordering value is null: the result is null.
  checkAnswer(
    sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"),
    Row(null) :: Nil
  )

  // structs as ordering value.
  checkAnswer(
    sql("select min_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
      "(('c', (10, 60))) AS tab(x, y)"),
    Row("a") :: Nil
  )

  // A null struct ordering value is not selected.
  checkAnswer(
    sql("select min_by(x, y) FROM VALUES (('a', null)), (('b', (10, 50))), " +
      "(('c', (10, 60))) AS tab(x, y)"),
    Row("b") :: Nil
  )

  // A map-typed ordering column must be rejected with an AnalysisException.
  withTempView("tempView") {
    // createOrReplaceTempView returns Unit, so its result is not bound to a name.
    Seq((0, "a"), (1, "b"), (2, "c"))
      .toDF("x", "y")
      .select($"x", map($"x", $"y").as("y"))
      .createOrReplaceTempView("tempView")
    val error = intercept[AnalysisException] {
      sql("SELECT min_by(x, y) FROM tempView").show()
    }
    assert(
      error.message.contains("function min_by does not support ordering on type map<int,string>"))
  }
}
}