Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2080,6 +2080,21 @@ def size(col):
return Column(sc._jvm.functions.size(_to_java_column(col)))


@since(2.4)
def array_max(col):
"""
Collection function: returns the maximum value of the array.

:param col: name of column or expression

>>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
>>> df.select(array_max(df.data).alias('max')).collect()
[Row(max=3), Row(max=10)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.array_max(_to_java_column(col)))


@since(1.5)
def sort_array(col, asc=True):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ object FunctionRegistry {
expression[MapValues]("map_values"),
expression[Size]("size"),
expression[SortArray]("sort_array"),
expression[ArrayMax]("array_max"),
CreateStruct.registryEntry,

// misc functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -674,11 +674,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
|if (!${eval.isNull} && (${ev.isNull} ||
| ${ctx.genGreater(dataType, eval.value, ev.value)})) {
| ${ev.isNull} = false;
| ${ev.value} = ${eval.value};
|}
|${ctx.reassignIfGreater(dataType, ev, eval)}
""".stripMargin
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,23 @@ class CodegenContext {
case _ => s"(${genComp(dataType, c1, c2)}) > 0"
}

/**
* Generates code for updating `partialResult` if `item` is greater than it.
*
* @param dataType data type of the expressions
* @param partialResult `ExprCode` representing the partial result which has to be updated
* @param item `ExprCode` representing the new expression to evaluate for the result
*/
def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = {
s"""
|if (!${item.isNull} && (${partialResult.isNull} ||
| ${genGreater(dataType, item.value, partialResult.value)})) {
| ${partialResult.isNull} = false;
| ${partialResult.value} = ${item.value};
|}
""".stripMargin
}

/**
* Generates code to do null safe execution, i.e. only execute the code when the input is not
* null by adding null check if necessary.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.util.Comparator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -287,3 +287,69 @@ case class ArrayContains(left: Expression, right: Expression)

override def prettyName: String = "array_contains"
}


/**
* Returns the maximum value in the array.
*/
@ExpressionDescription(
usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 20, null, 3));
20
""", since = "2.4.0")
case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

override def nullable: Boolean = true

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

override def checkInputDataTypes(): TypeCheckResult = {
val typeCheckResult = super.checkInputDataTypes()
if (typeCheckResult.isSuccess) {
TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
} else {
typeCheckResult
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childGen = child.genCode(ctx)
val javaType = CodeGenerator.javaType(dataType)
val i = ctx.freshName("i")
val item = ExprCode("",
isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
ev.copy(code =
s"""
|${childGen.code}
|boolean ${ev.isNull} = true;
|$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to use MIN value for each data type instead of default value?
If we perform this operation against (-10, -100, -1000), I think that we would get -1 as a result.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvm, isNull is used for assigning the initial value.

|if (!${childGen.isNull}) {
| for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) {
| ${ctx.reassignIfGreater(dataType, ev, item)}
| }
|}
""".stripMargin)
}

override protected def nullSafeEval(input: Any): Any = {
var max: Any = null
input.asInstanceOf[ArrayData].foreach(dataType, (_, item) =>
if (item != null && (max == null || ordering.gt(item, max))) {
max = item
}
)
max
}

override def dataType: DataType = child.dataType match {
case ArrayType(dt, _) => dt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also check if dt is orderable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the check in the checkInputDataTypes method, thanks.

case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
}

override def prettyName: String = "array_max"
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal("")), null)
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}

test("Array max") {
checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10)
checkEvaluation(
ArrayMax(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "abc")
checkEvaluation(ArrayMax(Literal.create(Seq(null), ArrayType(LongType))), null)
checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null)
checkEvaluation(
ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123)
}
}
8 changes: 8 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3300,6 +3300,14 @@ object functions {
*/
def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) }

/**
* Returns the maximum value in the array.
*
* @group collection_funcs
* @since 2.4.0
*/
def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) }

/**
* Returns an unordered array containing the keys of the map.
* @group collection_funcs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("array_max function") {
val df = Seq(
Seq[Option[Int]](Some(1), Some(3), Some(2)),
Seq.empty[Option[Int]],
Seq[Option[Int]](None),
Seq[Option[Int]](None, Some(1), Some(-100))
).toDF("a")

val answer = Seq(Row(3), Row(null), Row(null), Row(1))

checkAnswer(df.select(array_max(df("a"))), answer)
checkAnswer(df.selectExpr("array_max(a)"), answer)
}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
Expand Down