From 5c7e3c500aa46461cca1d2d802a6be4f7caa3cb4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 8 May 2019 19:39:44 +0800 Subject: [PATCH 1/6] Add max_by. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/aggregate/MaxBy.scala | 81 +++++++++++++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 43 ++++++++++ 3 files changed, 125 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxBy.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d725ef5da06e..ba75bbf09027 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -289,6 +289,7 @@ object FunctionRegistry { expression[Last]("last"), expression[Last]("last_value"), expression[Max]("max"), + expression[MaxBy]("max_by"), expression[Average]("mean"), expression[Min]("min"), expression[Percentile]("percentile"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxBy.scala new file mode 100644 index 000000000000..86c5ccdeea10 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxBy.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +@ExpressionDescription( + usage = "_FUNC_(x, y) - Returns the value of `x` associated with the maximum value of `y`.", + examples = """ + Examples: + > SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y); + b + """, + since = "3.0") +case class MaxBy(valueExpr: Expression, maxExpr: Expression) extends DeclarativeAggregate { + + override def children: Seq[Expression] = valueExpr :: maxExpr :: Nil + + override def nullable: Boolean = true + + // Return data type. 
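+  // The result type follows the value expression `x`, not the ordering
+  // expression `y`: max_by(x, y) returns the value of `x` that is paired
+  // with the maximum `y`.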
+ override def dataType: DataType = valueExpr.dataType + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(maxExpr.dataType, "function max_by") + + private lazy val max = AttributeReference("max", maxExpr.dataType)() + private lazy val value = AttributeReference("value", valueExpr.dataType)() + + override lazy val aggBufferAttributes: Seq[AttributeReference] = value :: max :: Nil + + override lazy val initialValues: Seq[Literal] = Seq( + /* value = */ Literal.create(null, valueExpr.dataType), + /* max = */ Literal.create(null, maxExpr.dataType) + ) + + override lazy val updateExpressions: Seq[Expression] = Seq( + /* value = */ + CaseWhen( + (And(IsNull(max), IsNull(maxExpr)), Literal.create(null, valueExpr.dataType)) :: + (IsNull(max), valueExpr) :: + (IsNull(maxExpr), value) :: Nil, + If(GreaterThan(max, maxExpr), value, valueExpr) + ), + /* max = */ greatest(max, maxExpr) + ) + + override lazy val mergeExpressions: Seq[Expression] = { + Seq( + /* value = */ + CaseWhen( + (And(IsNull(max.left), IsNull(max.right)), Literal.create(null, valueExpr.dataType)) :: + (IsNull(max.left), value.right) :: + (IsNull(max.right), value.left) :: Nil, + If(GreaterThan(max.left, max.right), value.left, value.right) + ), + /* max = */ greatest(max.left, max.right) + ) + } + + override lazy val evaluateExpression: AttributeReference = value +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 97aaa1b584af..e74a890a47f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -782,4 +782,47 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val countAndDistinct = df.select(count("*"), countDistinct("*")) checkAnswer(countAndDistinct, Row(100000, 100)) } + + test("max_by") { + val yearOfMaxEarnings = + sql("SELECT course, max_by(year, earnings) FROM courseSales GROUP BY course") + checkAnswer(yearOfMaxEarnings, Row("dotNET", 2013) :: Row("Java", 2013) :: Nil) + + checkAnswer( + sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"), + Row("b") :: Nil + ) + + checkAnswer( + sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"), + Row("c") :: Nil + ) + + checkAnswer( + sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"), + Row("c") :: Nil + ) + + checkAnswer( + sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"), + Row("b") :: Nil + ) + + checkAnswer( + sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"), + Row(null) :: Nil + ) + + withTempView("tempView") { + val dfWithMap = Seq((0, "a"), (1, "b"), (2, "c")) + .toDF("x", "y") + .select($"x", map($"x", $"y").as("y")) + .createOrReplaceTempView("tempView") + val error = intercept[AnalysisException] { + sql("SELECT max_by(x, y) FROM tempView").show + } + assert( + error.message.contains("function max_by does not support ordering on type map")) + } + } } From 79f1015caed838658d64874365613d761a605dcf Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 9 May 2019 13:54:58 +0800 Subject: [PATCH 2/6] Add min_by. 
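
Refactor max_by into a shared abstract base class, MaxMinBy, and add
min_by on top of it. The two functions share the aggregation buffer
layout, the null handling in the update/merge expressions, and the
final evaluation; they differ only in the comparison predicate
(GreaterThan vs. LessThan) and the ordering updater (greatest vs.
least).

A sketch of the expected behavior, using the values from the
ExpressionDescription examples below:

    SELECT max_by(x, y), min_by(x, y)
    FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y);
    -- max_by(x, y) -> 'b', min_by(x, y) -> 'a'
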
--- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/aggregate/MaxBy.scala | 81 ------------ .../expressions/aggregate/MaxByAndMinBy.scala | 123 ++++++++++++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 43 ++++++ 4 files changed, 167 insertions(+), 81 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxBy.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ba75bbf09027..986662c95114 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -292,6 +292,7 @@ object FunctionRegistry { expression[MaxBy]("max_by"), expression[Average]("mean"), expression[Min]("min"), + expression[MinBy]("min_by"), expression[Percentile]("percentile"), expression[Skewness]("skewness"), expression[ApproximatePercentile]("percentile_approx"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxBy.scala deleted file mode 100644 index 86c5ccdeea10..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxBy.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types._ - -@ExpressionDescription( - usage = "_FUNC_(x, y) - Returns the value of `x` associated with the maximum value of `y`.", - examples = """ - Examples: - > SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y); - b - """, - since = "3.0") -case class MaxBy(valueExpr: Expression, maxExpr: Expression) extends DeclarativeAggregate { - - override def children: Seq[Expression] = valueExpr :: maxExpr :: Nil - - override def nullable: Boolean = true - - // Return data type. 
- override def dataType: DataType = valueExpr.dataType - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(maxExpr.dataType, "function max_by") - - private lazy val max = AttributeReference("max", maxExpr.dataType)() - private lazy val value = AttributeReference("value", valueExpr.dataType)() - - override lazy val aggBufferAttributes: Seq[AttributeReference] = value :: max :: Nil - - override lazy val initialValues: Seq[Literal] = Seq( - /* value = */ Literal.create(null, valueExpr.dataType), - /* max = */ Literal.create(null, maxExpr.dataType) - ) - - override lazy val updateExpressions: Seq[Expression] = Seq( - /* value = */ - CaseWhen( - (And(IsNull(max), IsNull(maxExpr)), Literal.create(null, valueExpr.dataType)) :: - (IsNull(max), valueExpr) :: - (IsNull(maxExpr), value) :: Nil, - If(GreaterThan(max, maxExpr), value, valueExpr) - ), - /* max = */ greatest(max, maxExpr) - ) - - override lazy val mergeExpressions: Seq[Expression] = { - Seq( - /* value = */ - CaseWhen( - (And(IsNull(max.left), IsNull(max.right)), Literal.create(null, valueExpr.dataType)) :: - (IsNull(max.left), value.right) :: - (IsNull(max.right), value.left) :: Nil, - If(GreaterThan(max.left, max.right), value.left, value.right) - ), - /* max = */ greatest(max.left, max.right) - ) - } - - override lazy val evaluateExpression: AttributeReference = value -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala new file mode 100644 index 000000000000..3019ff466dcb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +/** + * The shared abstract superclass for `MaxBy` and `MinBy` SQL aggregate functions. + */ +abstract class MaxMinBy extends DeclarativeAggregate { + + def valueExpr: Expression + def orderingExpr: Expression + + protected def funcName: String + // The predicate compares two ordering values. + protected def predicate(oldExpr: Expression, newExpr: Expression): Expression + // The arithmetic expression returns greatest/least value of all parameters. + // Used to pick up updated ordering value. 
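+  // For max_by this is greatest(oldExpr, newExpr), for min_by it is
+  // least(oldExpr, newExpr); both skip null inputs, so a null ordering in
+  // the buffer is replaced by the first non-null ordering value seen.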
+ protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression + + override def children: Seq[Expression] = valueExpr :: orderingExpr :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = valueExpr.dataType + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(orderingExpr.dataType, s"function $funcName") + + private lazy val ordering = AttributeReference("ordering", orderingExpr.dataType)() + private lazy val value = AttributeReference("value", valueExpr.dataType)() + + override lazy val aggBufferAttributes: Seq[AttributeReference] = value :: ordering :: Nil + + private lazy val nullValue = Literal.create(null, valueExpr.dataType) + private lazy val nullOrdering = Literal.create(null, orderingExpr.dataType) + + override lazy val initialValues: Seq[Literal] = Seq( + /* value = */ nullValue, + /* ordering = */ nullOrdering + ) + + override lazy val updateExpressions: Seq[Expression] = Seq( + /* value = */ + CaseWhen( + (And(IsNull(ordering), IsNull(orderingExpr)), nullValue) :: + (IsNull(ordering), valueExpr) :: + (IsNull(orderingExpr), value) :: Nil, + If(predicate(ordering, orderingExpr), value, valueExpr) + ), + /* ordering = */ orderingUpdater(ordering, orderingExpr) + ) + + override lazy val mergeExpressions: Seq[Expression] = Seq( + /* value = */ + CaseWhen( + (And(IsNull(ordering.left), IsNull(ordering.right)), nullValue) :: + (IsNull(ordering.left), value.right) :: + (IsNull(ordering.right), value.left) :: Nil, + If(predicate(ordering.left, ordering.right), value.left, value.right) + ), + /* ordering = */ orderingUpdater(ordering.left, ordering.right) + ) + + override lazy val evaluateExpression: AttributeReference = value +} + +@ExpressionDescription( + usage = "_FUNC_(x, y) - Returns the value of `x` associated with the maximum value of `y`.", + examples = """ + Examples: + > SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y); + b + """, + since = "3.0") +case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy { + override protected def funcName: String = "max_by" + + override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression = + GreaterThan(oldExpr, newExpr) + + override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression = + greatest(oldExpr, newExpr) +} + +@ExpressionDescription( + usage = "_FUNC_(x, y) - Returns the value of `x` associated with the minimum value of `y`.", + examples = """ + Examples: + > SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y); + a + """, + since = "3.0") +case class MinBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy { + override protected def funcName: String = "min_by" + + override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression = + LessThan(oldExpr, newExpr) + + override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression = + least(oldExpr, newExpr) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e74a890a47f9..922fec6e4e7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -825,4 +825,47 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { 
error.message.contains("function max_by does not support ordering on type map")) } } + + test("min_by") { + val yearOfMaxEarnings = + sql("SELECT course, min_by(year, earnings) FROM courseSales GROUP BY course") + checkAnswer(yearOfMaxEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil) + + checkAnswer( + sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"), + Row("a") :: Nil + ) + + checkAnswer( + sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"), + Row("a") :: Nil + ) + + checkAnswer( + sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"), + Row("c") :: Nil + ) + + checkAnswer( + sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"), + Row("a") :: Nil + ) + + checkAnswer( + sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"), + Row(null) :: Nil + ) + + withTempView("tempView") { + val dfWithMap = Seq((0, "a"), (1, "b"), (2, "c")) + .toDF("x", "y") + .select($"x", map($"x", $"y").as("y")) + .createOrReplaceTempView("tempView") + val error = intercept[AnalysisException] { + sql("SELECT min_by(x, y) FROM tempView").show + } + assert( + error.message.contains("function min_by does not support ordering on type map")) + } + } } From 798f0fa1bf8b51819a34a36239fa3e29db3ff53b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 10 May 2019 00:03:39 +0800 Subject: [PATCH 3/6] Add test. --- .../spark/sql/DataFrameAggregateSuite.scala | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 922fec6e4e7c..489d989e3b3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -813,6 +813,19 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null) :: Nil ) + // structs as ordering value. + checkAnswer( + sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " + + "(('c', (10, 60))) AS tab(x, y)"), + Row("c") :: Nil + ) + + checkAnswer( + sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " + + "(('c', null)) AS tab(x, y)"), + Row("b") :: Nil + ) + withTempView("tempView") { val dfWithMap = Seq((0, "a"), (1, "b"), (2, "c")) .toDF("x", "y") @@ -856,6 +869,19 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null) :: Nil ) + // structs as ordering value. + checkAnswer( + sql("select min_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " + + "(('c', (10, 60))) AS tab(x, y)"), + Row("a") :: Nil + ) + + checkAnswer( + sql("select min_by(x, y) FROM VALUES (('a', null)), (('b', (10, 50))), " + + "(('c', (10, 60))) AS tab(x, y)"), + Row("b") :: Nil + ) + withTempView("tempView") { val dfWithMap = Seq((0, "a"), (1, "b"), (2, "c")) .toDF("x", "y") From 26b5a32e0a7f38fd575e6479a344971e4b82bc5b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 11 May 2019 09:53:45 +0800 Subject: [PATCH 4/6] Using DSL. 
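
Switch the explicit expression constructors to the Catalyst expression
DSL that is already imported in MaxByAndMinBy.scala
(org.apache.spark.sql.catalyst.dsl.expressions._). The DSL builds the
same trees with less noise, as the hunks below show, e.g.:

    And(IsNull(ordering), IsNull(orderingExpr))  // explicit constructors
    ordering.isNull && orderingExpr.isNull       // equivalent DSL form

Also fix a copy-pasted name in the min_by test: yearOfMaxEarnings is
renamed to yearOfMinEarnings.
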
--- .../expressions/aggregate/MaxByAndMinBy.scala | 12 ++++++------ .../apache/spark/sql/DataFrameAggregateSuite.scala | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala index 3019ff466dcb..7cd34f07d633 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala @@ -64,9 +64,9 @@ abstract class MaxMinBy extends DeclarativeAggregate { override lazy val updateExpressions: Seq[Expression] = Seq( /* value = */ CaseWhen( - (And(IsNull(ordering), IsNull(orderingExpr)), nullValue) :: - (IsNull(ordering), valueExpr) :: - (IsNull(orderingExpr), value) :: Nil, + (ordering.isNull && orderingExpr.isNull, nullValue) :: + (ordering.isNull, valueExpr) :: + (orderingExpr.isNull, value) :: Nil, If(predicate(ordering, orderingExpr), value, valueExpr) ), /* ordering = */ orderingUpdater(ordering, orderingExpr) @@ -75,9 +75,9 @@ abstract class MaxMinBy extends DeclarativeAggregate { override lazy val mergeExpressions: Seq[Expression] = Seq( /* value = */ CaseWhen( - (And(IsNull(ordering.left), IsNull(ordering.right)), nullValue) :: - (IsNull(ordering.left), value.right) :: - (IsNull(ordering.right), value.left) :: Nil, + (ordering.left.isNull && ordering.right.isNull, nullValue) :: + (ordering.left.isNull, value.right) :: + (ordering.right.isNull, value.left) :: Nil, If(predicate(ordering.left, ordering.right), value.left, value.right) ), /* ordering = */ orderingUpdater(ordering.left, ordering.right) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 489d989e3b3a..d89ecc22a7c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -840,9 +840,9 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("min_by") { - val yearOfMaxEarnings = + val yearOfMinEarnings = sql("SELECT course, min_by(year, earnings) FROM courseSales GROUP BY course") - checkAnswer(yearOfMaxEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil) + checkAnswer(yearOfMinEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil) checkAnswer( sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"), From dd1d9de86631a18166c2e0d955247a07605c219a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 13 May 2019 16:28:34 +0800 Subject: [PATCH 5/6] Address comments. 
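
Rename the aggregation buffer attributes so they describe what they
hold (ordering -> maxOrdering, value -> valueWithMaxOrdering), document
them, and use the DSL comparison operators in the predicates
(oldExpr > newExpr / oldExpr < newExpr instead of GreaterThan and
LessThan).

For reference, a hand trace of the max_by update path over the rows
('a', 10), ('b', null), ('c', 20), matching the null-handling tests:

    (valueWithMaxOrdering, maxOrdering) = (null, null)   // initialValues
    row ('a', 10):   buffer ordering is null   -> ('a', 10)
    row ('b', null): incoming ordering is null -> ('a', 10)
    row ('c', 20):   10 > 20 is false          -> ('c', 20)

so evaluateExpression returns 'c'.
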
--- .../expressions/aggregate/MaxByAndMinBy.scala | 44 ++++++++++--------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala index 7cd34f07d633..e5b673ae5798 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala @@ -48,42 +48,46 @@ abstract class MaxMinBy extends DeclarativeAggregate { override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForOrderingExpr(orderingExpr.dataType, s"function $funcName") - private lazy val ordering = AttributeReference("ordering", orderingExpr.dataType)() - private lazy val value = AttributeReference("value", valueExpr.dataType)() + // The attributes used to keep extremum (max or min) and associated aggregated values. + private lazy val maxOrdering = AttributeReference("maxOrdering", orderingExpr.dataType)() + private lazy val valueWithMaxOrdering = + AttributeReference("valueWithMaxOrdering", valueExpr.dataType)() - override lazy val aggBufferAttributes: Seq[AttributeReference] = value :: ordering :: Nil + override lazy val aggBufferAttributes: Seq[AttributeReference] = + valueWithMaxOrdering :: maxOrdering :: Nil private lazy val nullValue = Literal.create(null, valueExpr.dataType) private lazy val nullOrdering = Literal.create(null, orderingExpr.dataType) override lazy val initialValues: Seq[Literal] = Seq( - /* value = */ nullValue, - /* ordering = */ nullOrdering + /* valueWithMaxOrdering = */ nullValue, + /* maxOrdering = */ nullOrdering ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* value = */ + /* valueWithMaxOrdering = */ CaseWhen( - (ordering.isNull && orderingExpr.isNull, nullValue) :: - (ordering.isNull, valueExpr) :: - (orderingExpr.isNull, value) :: Nil, - If(predicate(ordering, orderingExpr), value, valueExpr) + (maxOrdering.isNull && orderingExpr.isNull, nullValue) :: + (maxOrdering.isNull, valueExpr) :: + (orderingExpr.isNull, valueWithMaxOrdering) :: Nil, + If(predicate(maxOrdering, orderingExpr), valueWithMaxOrdering, valueExpr) ), - /* ordering = */ orderingUpdater(ordering, orderingExpr) + /* maxOrdering = */ orderingUpdater(maxOrdering, orderingExpr) ) override lazy val mergeExpressions: Seq[Expression] = Seq( - /* value = */ + /* valueWithMaxOrdering = */ CaseWhen( - (ordering.left.isNull && ordering.right.isNull, nullValue) :: - (ordering.left.isNull, value.right) :: - (ordering.right.isNull, value.left) :: Nil, - If(predicate(ordering.left, ordering.right), value.left, value.right) + (maxOrdering.left.isNull && maxOrdering.right.isNull, nullValue) :: + (maxOrdering.left.isNull, valueWithMaxOrdering.right) :: + (maxOrdering.right.isNull, valueWithMaxOrdering.left) :: Nil, + If(predicate(maxOrdering.left, maxOrdering.right), + valueWithMaxOrdering.left, valueWithMaxOrdering.right) ), - /* ordering = */ orderingUpdater(ordering.left, ordering.right) + /* maxOrdering = */ orderingUpdater(maxOrdering.left, maxOrdering.right) ) - override lazy val evaluateExpression: AttributeReference = value + override lazy val evaluateExpression: AttributeReference = valueWithMaxOrdering } @ExpressionDescription( @@ -98,7 +102,7 @@ case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMin override protected def funcName: 
String = "max_by" override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression = - GreaterThan(oldExpr, newExpr) + oldExpr > newExpr override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression = greatest(oldExpr, newExpr) @@ -116,7 +120,7 @@ case class MinBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMin override protected def funcName: String = "min_by" override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression = - LessThan(oldExpr, newExpr) + oldExpr < newExpr override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression = least(oldExpr, newExpr) From 05f1767dc49e6caebd91abf207014bcd8029b4f1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 13 May 2019 16:39:19 +0800 Subject: [PATCH 6/6] Renaming variables. --- .../expressions/aggregate/MaxByAndMinBy.scala | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala index e5b673ae5798..c7fdb15130c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala @@ -49,45 +49,46 @@ abstract class MaxMinBy extends DeclarativeAggregate { TypeUtils.checkForOrderingExpr(orderingExpr.dataType, s"function $funcName") // The attributes used to keep extremum (max or min) and associated aggregated values. - private lazy val maxOrdering = AttributeReference("maxOrdering", orderingExpr.dataType)() - private lazy val valueWithMaxOrdering = - AttributeReference("valueWithMaxOrdering", valueExpr.dataType)() + private lazy val extremumOrdering = + AttributeReference("extremumOrdering", orderingExpr.dataType)() + private lazy val valueWithExtremumOrdering = + AttributeReference("valueWithExtremumOrdering", valueExpr.dataType)() override lazy val aggBufferAttributes: Seq[AttributeReference] = - valueWithMaxOrdering :: maxOrdering :: Nil + valueWithExtremumOrdering :: extremumOrdering :: Nil private lazy val nullValue = Literal.create(null, valueExpr.dataType) private lazy val nullOrdering = Literal.create(null, orderingExpr.dataType) override lazy val initialValues: Seq[Literal] = Seq( - /* valueWithMaxOrdering = */ nullValue, - /* maxOrdering = */ nullOrdering + /* valueWithExtremumOrdering = */ nullValue, + /* extremumOrdering = */ nullOrdering ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* valueWithMaxOrdering = */ + /* valueWithExtremumOrdering = */ CaseWhen( - (maxOrdering.isNull && orderingExpr.isNull, nullValue) :: - (maxOrdering.isNull, valueExpr) :: - (orderingExpr.isNull, valueWithMaxOrdering) :: Nil, - If(predicate(maxOrdering, orderingExpr), valueWithMaxOrdering, valueExpr) + (extremumOrdering.isNull && orderingExpr.isNull, nullValue) :: + (extremumOrdering.isNull, valueExpr) :: + (orderingExpr.isNull, valueWithExtremumOrdering) :: Nil, + If(predicate(extremumOrdering, orderingExpr), valueWithExtremumOrdering, valueExpr) ), - /* maxOrdering = */ orderingUpdater(maxOrdering, orderingExpr) + /* extremumOrdering = */ orderingUpdater(extremumOrdering, orderingExpr) ) override lazy val mergeExpressions: Seq[Expression] = Seq( - /* valueWithMaxOrdering = */ + /* valueWithExtremumOrdering = */ CaseWhen( - 
(maxOrdering.left.isNull && maxOrdering.right.isNull, nullValue) :: - (maxOrdering.left.isNull, valueWithMaxOrdering.right) :: - (maxOrdering.right.isNull, valueWithMaxOrdering.left) :: Nil, - If(predicate(maxOrdering.left, maxOrdering.right), - valueWithMaxOrdering.left, valueWithMaxOrdering.right) + (extremumOrdering.left.isNull && extremumOrdering.right.isNull, nullValue) :: + (extremumOrdering.left.isNull, valueWithExtremumOrdering.right) :: + (extremumOrdering.right.isNull, valueWithExtremumOrdering.left) :: Nil, + If(predicate(extremumOrdering.left, extremumOrdering.right), + valueWithExtremumOrdering.left, valueWithExtremumOrdering.right) ), - /* maxOrdering = */ orderingUpdater(maxOrdering.left, maxOrdering.right) + /* extremumOrdering = */ orderingUpdater(extremumOrdering.left, extremumOrdering.right) ) - override lazy val evaluateExpression: AttributeReference = valueWithMaxOrdering + override lazy val evaluateExpression: AttributeReference = valueWithExtremumOrdering } @ExpressionDescription(