diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1b192680f0795..f3492ae42639c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2080,6 +2080,21 @@ def size(col): return Column(sc._jvm.functions.size(_to_java_column(col))) +@since(2.4) +def array_max(col): + """ + Collection function: returns the maximum value of the array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_max(df.data).alias('max')).collect() + [Row(max=3), Row(max=10)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_max(_to_java_column(col))) + + @since(1.5) def sort_array(col, asc=True): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 747016beb06e7..572e20eb0821a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -408,6 +408,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayMax]("array_max"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9212c3de1f814..942dfd4292610 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -674,11 +674,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && (${ev.isNull} || - | ${ctx.genGreater(dataType, eval.value, ev.value)})) { - | ${ev.isNull} = false; - | ${ev.value} = ${eval.value}; - |} + |${ctx.reassignIfGreater(dataType, ev, eval)} """.stripMargin ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 0abfc9fa4c465..c86c5beded9d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -699,6 +699,23 @@ class CodegenContext { case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code for updating `partialResult` if `item` is greater than it. + * + * @param dataType data type of the expressions + * @param partialResult `ExprCode` representing the partial result which has to be updated + * @param item `ExprCode` representing the new expression to evaluate for the result + */ + def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { + s""" + |if (!${item.isNull} && (${partialResult.isNull} || + | ${genGreater(dataType, item.value, partialResult.value)})) { + | ${partialResult.isNull} = false; + | ${partialResult.value} = ${item.value}; + |} + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 91188da8b0bd3..e2614a179aad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -21,7 +21,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ /** @@ -287,3 +287,69 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Returns the maximum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 20 + """, since = "2.4.0") +case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode("", + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + s""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfGreater(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var max: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (max == null || ordering.gt(item, max))) { + max = item + } + ) + max + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_max" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 020687e4b3a27..a2384019533b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -105,4 +105,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Array max") { + checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10) + checkEvaluation( + ArrayMax(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "abc") + checkEvaluation(ArrayMax(Literal.create(Seq(null), ArrayType(LongType))), null) + checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null) + checkEvaluation( + ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c658f25ced053..daf407926dca4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3300,6 +3300,14 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns the maximum value in the array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 50e475984f458..5d5d92c84df6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_max function") { + val df = Seq( + Seq[Option[Int]](Some(1), Some(3), Some(2)), + Seq.empty[Option[Int]], + Seq[Option[Int]](None), + Seq[Option[Int]](None, Some(1), Some(-100)) + ).toDF("a") + + val answer = Seq(Row(3), Row(null), Row(null), Row(1)) + + checkAnswer(df.select(array_max(df("a"))), answer) + checkAnswer(df.selectExpr("array_max(a)"), answer) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {