From 93c5b8a48806a4b4210ff67fcc157a2a79ab8eaa Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 10 Apr 2018 13:12:40 +0200 Subject: [PATCH 1/9] [SPARK-23917][SQL] Add array_max function --- python/pyspark/sql/functions.py | 15 +++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 17 ++++++ .../expressions/collectionOperations.scala | 60 ++++++++++++++++++- .../CollectionExpressionsSuite.scala | 10 ++++ .../org/apache/spark/sql/functions.scala | 8 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 14 +++++ 8 files changed, 125 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1b192680f0795..b9fb7f2514998 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2080,6 +2080,21 @@ def size(col): return Column(sc._jvm.functions.size(_to_java_column(col))) +@since(2.4) +def array_max(col): + """ + Collection function: returns the maximum value of the array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, 3],),([None, 10, -1],)], ['data']) + >>> df.select(array_max(df.data).alias('max')).collect() + [Row(max=3), Row(max=10)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_max(_to_java_column(col))) + + @since(1.5) def sort_array(col, asc=True): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 747016beb06e7..572e20eb0821a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -408,6 +408,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayMax]("array_max"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index defd6f3cd8849..0378ab4c40c94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -676,11 +676,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && (${ev.isNull} || - | ${ctx.genGreater(dataType, eval.value, ev.value)})) { - | ${ev.isNull} = false; - | ${ev.value} = ${eval.value}; - |} + |${ctx.reassignIfGreater(dataType, ev, eval)} """.stripMargin ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c9c60ef1be640..2772d822fafd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -697,6 +697,23 @@ class CodegenContext { case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code for updating `partialResult` if `item` is greater than it. + * + * @param dataType data type of the expressions + * @param partialResult `ExprCode` representing the partial result which has to be updated + * @param item `ExprCode` representing the new expression to evaluate for the result + */ + def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { + s""" + |if (!${item.isNull} && (${partialResult.isNull} || + | ${genGreater(dataType, item.value, partialResult.value)})) { + | ${partialResult.isNull} = false; + | ${partialResult.value} = ${item.value}; + |} + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 91188da8b0bd3..52d1fb26539db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -21,7 +21,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ /** @@ -287,3 +287,61 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Returns the maximum value in the array. + */ +@ExpressionDescription( +usage = "_FUNC_(array) - Returns the maximum value in the array.", +examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 20 + """, since = "2.4.0") +case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = + child.nullable || child.dataType.asInstanceOf[ArrayType].containsNull + + override def foldable: Boolean = child.foldable + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode("", + isNull = StatementValue(s"${childGen.value}.isNullAt($i)", "boolean"), + value = StatementValue(CodeGenerator.getValue(childGen.value, dataType, i), javaType)) + ev.copy(code = + s""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfGreater(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var max: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (max == null || ordering.gt(item, max))) { + max = item + } + ) + max + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException("array_max accepts only arrays.") + } +} \ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 020687e4b3a27..a2384019533b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -105,4 +105,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Array max") { + checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10) + checkEvaluation( + ArrayMax(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "abc") + checkEvaluation(ArrayMax(Literal.create(Seq(null), ArrayType(LongType))), null) + checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null) + checkEvaluation( + ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c658f25ced053..daf407926dca4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3300,6 +3300,14 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns the maximum value in the array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 50e475984f458..5d5d92c84df6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_max function") { + val df = Seq( + Seq[Option[Int]](Some(1), Some(3), Some(2)), + Seq.empty[Option[Int]], + Seq[Option[Int]](None), + Seq[Option[Int]](None, Some(1), Some(-100)) + ).toDF("a") + + val answer = Seq(Row(3), Row(null), Row(null), Row(1)) + + checkAnswer(df.select(array_max(df("a"))), answer) + checkAnswer(df.selectExpr("array_max(a)"), answer) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From a296bc0db8b8d3befa05b7d0a8faedea4f21a625 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 10 Apr 2018 13:25:00 +0200 Subject: [PATCH 2/9] fix scalastyle --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 52d1fb26539db..93eb57d638aa8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -344,4 +344,4 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast case ArrayType(dt, _) => dt case _ => throw new IllegalStateException("array_max accepts only arrays.") } -} \ No newline at end of file +} From c8c1d0385f9ccaa714f5f57d3e65c12bf9586447 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 10 Apr 2018 13:44:32 +0200 Subject: [PATCH 3/9] add missing space --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b9fb7f2514998..61924b802787a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2087,7 +2087,7 @@ def array_max(col): :param col: name of column or expression - >>> df = spark.createDataFrame([([2, 1, 3],),([None, 10, -1],)], ['data']) + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) >>> df.select(array_max(df.data).alias('max')).collect() [Row(max=3), Row(max=10)] """ From d017ccf05c9787521b4af7489b20e96c69e4b8d5 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 10 Apr 2018 15:10:27 +0200 Subject: [PATCH 4/9] add orderable type check --- .../catalyst/expressions/collectionOperations.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 93eb57d638aa8..691eadd327905 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -310,6 +310,15 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) @@ -342,6 +351,6 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt - case _ => throw new IllegalStateException("array_max accepts only arrays.") + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") } } From e082f0017dc670441e96a9b7d2ffa527302db2e3 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 10 Apr 2018 15:51:16 +0200 Subject: [PATCH 5/9] fix indentation --- python/pyspark/sql/functions.py | 2 +- .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 61924b802787a..f3492ae42639c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2090,7 +2090,7 @@ def array_max(col): >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) >>> df.select(array_max(df.data).alias('max')).collect() [Row(max=3), Row(max=10)] - """ + """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_max(_to_java_column(col))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 691eadd327905..ce9ff6b1769a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -293,8 +293,8 @@ case class ArrayContains(left: Expression, right: Expression) * Returns the maximum value in the array. */ @ExpressionDescription( -usage = "_FUNC_(array) - Returns the maximum value in the array.", -examples = """ + usage = "_FUNC_(array) - Returns the maximum value in the array.", + examples = """ Examples: > SELECT _FUNC_(array(1, 20, null, 3)); 20 From a4fd6165227a148956368c2c8e86c99aac1267b5 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 12 Apr 2018 10:16:06 +0200 Subject: [PATCH 6/9] address comment --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ce9ff6b1769a1..f9c423a98a0a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -304,8 +304,6 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def nullable: Boolean = child.nullable || child.dataType.asInstanceOf[ArrayType].containsNull - override def foldable: Boolean = child.foldable - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) From 8dfd2634644b3b080c82bc5a4a81fdab22147b05 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 12 Apr 2018 11:03:49 +0200 Subject: [PATCH 7/9] update after codegen changes --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index f9c423a98a0a6..381c05ddf8501 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -322,8 +322,8 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast val javaType = CodeGenerator.javaType(dataType) val i = ctx.freshName("i") val item = ExprCode("", - isNull = StatementValue(s"${childGen.value}.isNullAt($i)", "boolean"), - value = StatementValue(CodeGenerator.getValue(childGen.value, dataType, i), javaType)) + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) ev.copy(code = s""" |${childGen.code} From e739a0a247bc3782ee4348246eff921c86f83e13 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 12 Apr 2018 16:03:36 +0200 Subject: [PATCH 8/9] fix nullable --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 381c05ddf8501..b3306acd39813 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -301,8 +301,7 @@ case class ArrayContains(left: Expression, right: Expression) """, since = "2.4.0") case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - override def nullable: Boolean = - child.nullable || child.dataType.asInstanceOf[ArrayType].containsNull + override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) From 1cde795fe96b915f7b322ea1746c436d51391528 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 13 Apr 2018 11:46:19 +0200 Subject: [PATCH 9/9] address comment --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b3306acd39813..e2614a179aad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -293,7 +293,7 @@ case class ArrayContains(left: Expression, right: Expression) * Returns the maximum value in the array. */ @ExpressionDescription( - usage = "_FUNC_(array) - Returns the maximum value in the array.", + usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.", examples = """ Examples: > SELECT _FUNC_(array(1, 20, null, 3)); @@ -350,4 +350,6 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast case ArrayType(dt, _) => dt case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") } + + override def prettyName: String = "array_max" }