72 commits
78f960b
Adds a array_prepend expression to catalyst
navinvishy Jan 28, 2023
f2d4f68
Fix null handling
navinvishy Jan 28, 2023
85a8b4c
Fix
navinvishy Jan 28, 2023
d7b601c
Fix
navinvishy Jan 28, 2023
c703f85
Merge branch 'array-prepend' of https://github.com/navinvishy/spark i…
navinvishy Jan 28, 2023
d827d21
Lint
navinvishy Feb 8, 2023
4e7aa1e
Lint
navinvishy Feb 8, 2023
89e8917
Merge branch 'apache:master' into array-prepend
navinvishy Feb 8, 2023
ff8c19e
Add examples of usage and fix test
navinvishy Feb 9, 2023
4f3a968
Fix tests
navinvishy Feb 10, 2023
ec9ea76
Fix types
navinvishy Feb 11, 2023
c728581
Fix tests
navinvishy Feb 27, 2023
6eba188
Fix python linter
navinvishy Feb 27, 2023
63ff6cc
Add test for null cases
navinvishy Feb 27, 2023
73a7dd7
Fix type of array
navinvishy Feb 27, 2023
6f97761
Adds a array_prepend expression to catalyst
navinvishy Jan 28, 2023
fec08e9
Fix null handling
navinvishy Jan 28, 2023
a8da345
Fix
navinvishy Jan 28, 2023
3af8fd9
Fix
navinvishy Jan 28, 2023
2cd3e18
Lint
navinvishy Feb 8, 2023
a307cb4
Lint
navinvishy Feb 8, 2023
7b24500
Add examples of usage and fix test
navinvishy Feb 9, 2023
f299962
Fix tests
navinvishy Feb 10, 2023
7ce00b8
Fix types
navinvishy Feb 11, 2023
505b8e2
Fix tests
navinvishy Feb 27, 2023
216ca4c
Fix python linter
navinvishy Feb 27, 2023
c121168
Add test for null cases
navinvishy Feb 27, 2023
ec503f9
Fix type of array
navinvishy Feb 27, 2023
86a948f
Merge branch 'array-prepend' of https://github.com/navinvishy/spark i…
navinvishy Feb 27, 2023
f1c0186
Address comments
navinvishy Mar 1, 2023
34cb724
Update version
navinvishy Mar 13, 2023
baa6cc7
Address review comments
navinvishy Mar 16, 2023
8aa8ae5
Adds a array_prepend expression to catalyst
navinvishy Jan 28, 2023
1ba91c7
Fix null handling
navinvishy Jan 28, 2023
90c0c28
Fix
navinvishy Jan 28, 2023
0a69172
Fix
navinvishy Jan 28, 2023
db59880
Lint
navinvishy Feb 8, 2023
f0d9329
Lint
navinvishy Feb 8, 2023
ae5b65e
Add examples of usage and fix test
navinvishy Feb 9, 2023
af3ee0a
Fix tests
navinvishy Feb 10, 2023
3265717
Fix types
navinvishy Feb 11, 2023
7df63ea
Fix tests
navinvishy Feb 27, 2023
b4fbbd5
Fix python linter
navinvishy Feb 27, 2023
413af39
Add test for null cases
navinvishy Feb 27, 2023
9992e33
Fix type of array
navinvishy Feb 27, 2023
684a7d9
Adds a array_prepend expression to catalyst
navinvishy Jan 28, 2023
93f1819
Fix null handling
navinvishy Jan 28, 2023
a8db7b3
Fix
navinvishy Jan 28, 2023
82186b9
Fix
navinvishy Jan 28, 2023
30988b7
Lint
navinvishy Feb 8, 2023
09a61ca
Lint
navinvishy Feb 8, 2023
d188279
Add examples of usage and fix test
navinvishy Feb 9, 2023
15b713d
Fix tests
navinvishy Feb 10, 2023
380b156
Fix types
navinvishy Feb 11, 2023
4ecfac8
Fix tests
navinvishy Feb 27, 2023
160db20
Fix python linter
navinvishy Feb 27, 2023
3aa673f
Address comments
navinvishy Mar 1, 2023
8b480a5
Update version
navinvishy Mar 13, 2023
19fe924
Address review comments
navinvishy Mar 16, 2023
b1cf31a
Adds a array_prepend expression to catalyst
navinvishy Jan 28, 2023
c0c6a51
Fix null handling
navinvishy Jan 28, 2023
422f393
Fix
navinvishy Jan 28, 2023
2e193e4
Lint
navinvishy Feb 8, 2023
95673b8
Add examples of usage and fix test
navinvishy Feb 9, 2023
46e6dd7
Fix tests
navinvishy Feb 10, 2023
67a64da
Fix types
navinvishy Feb 11, 2023
19505ff
Fix tests
navinvishy Feb 27, 2023
52078ff
Fix python linter
navinvishy Feb 27, 2023
e0244ad
Merge branch 'array-prepend' of https://github.com/navinvishy/spark i…
navinvishy Mar 17, 2023
8cd56bd
Fix merge
navinvishy Mar 17, 2023
6634737
Fix MiMa
navinvishy Mar 17, 2023
4dffbf7
Fix indent
navinvishy Mar 17, 2023
@@ -177,6 +177,7 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.broadcast"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedlit"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedLit"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.array_prepend"),

// RelationalGroupedDataset
ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"),
1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.sql/functions.rst
@@ -159,6 +159,7 @@ Collection Functions
array_sort
array_insert
array_remove
array_prepend
array_distinct
array_intersect
array_union
30 changes: 30 additions & 0 deletions python/pyspark/sql/functions.py
@@ -7631,6 +7631,36 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
return _invoke_function_over_columns("get", col, index)


@try_remote_functions
def array_prepend(col: "ColumnOrName", value: Any) -> Column:
"""
Collection function: returns an array containing the given value as well as all elements
from the array. The new element is positioned at the beginning of the array.

.. versionadded:: 3.5.0

Parameters
----------
col : :class:`~pyspark.sql.Column` or str
name of column containing array
value :
a literal value, or a :class:`~pyspark.sql.Column` expression.

Returns
-------
:class:`~pyspark.sql.Column`
an array with the given value prepended.

Examples
--------
>>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data'])
>>> df.select(array_prepend(df.data, 1)).collect()
[Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])]
"""
return _invoke_function_over_columns("array_prepend", col, lit(value))


@try_remote_functions
def array_remove(col: "ColumnOrName", element: Any) -> Column:
"""
@@ -697,6 +697,7 @@ object FunctionRegistry {
expression[Sequence]("sequence"),
expression[ArrayRepeat]("array_repeat"),
expression[ArrayRemove]("array_remove"),
expression[ArrayPrepend]("array_prepend"),
expression[ArrayDistinct]("array_distinct"),
expression[ArrayTransform]("transform"),
expression[MapFilter]("map_filter"),
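Once the expression is registered in FunctionRegistry under the name array_prepend, it becomes resolvable from SQL by that name. Below is a minimal sketch of exercising it end to end; the object name, session setup, and the availability of a Spark build containing this change are assumptions for illustration, not part of this PR.

// Illustrative sketch only: resolves the newly registered "array_prepend" by name.
// Assumes a Spark build containing this change and a local session.
import org.apache.spark.sql.SparkSession

object ArrayPrependSqlDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("array_prepend-sql-demo")
      .getOrCreate()
    // The function name is looked up through FunctionRegistry at analysis time.
    val df = spark.sql("SELECT array_prepend(array(1, 2, 3), 0) AS prepended")
    df.show(false) // expected: a single row containing [0, 1, 2, 3]
    spark.stop()
  }
}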
@@ -1399,6 +1399,152 @@ case class ArrayContains(left: Expression, right: Expression)
copy(left = newLeft, right = newRight)
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(array, element) - Adds the element at the beginning of the array passed as the first
argument. The type of the element should be the same as the type of the elements of the array.
A null element is also prepended to the array, but if the array passed is NULL,
the output is NULL.
""",
examples = """
Examples:
> SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
["d","b","d","c","a"]
> SELECT _FUNC_(array(1, 2, 3, null), null);
[null,1,2,3,null]
> SELECT _FUNC_(CAST(null as Array<Int>), 2);
NULL
""",
group = "array_funcs",
since = "3.5.0")
case class ArrayPrepend(left: Expression, right: Expression)
extends BinaryExpression
with ImplicitCastInputTypes
with ComplexTypeMergingExpression
with QueryErrorsBase {

override def nullable: Boolean = left.nullable

@transient protected lazy val elementType: DataType =
inputTypes.head.asInstanceOf[ArrayType].elementType

override def eval(input: InternalRow): Any = {
val value1 = left.eval(input)
if (value1 == null) {
null
} else {
val value2 = right.eval(input)
nullSafeEval(value1, value2)
}
}
override def nullSafeEval(arr: Any, elementData: Any): Any = {
val arrayData = arr.asInstanceOf[ArrayData]
val numberOfElements = arrayData.numElements() + 1
if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements)
}
val finalData = new Array[Any](numberOfElements)
finalData.update(0, elementData)
arrayData.foreach(elementType, (i: Int, v: Any) => finalData.update(i + 1, v))
new GenericArrayData(finalData)
}
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val leftGen = left.genCode(ctx)
val rightGen = right.genCode(ctx)
val f = (arr: String, value: String) => {
val newArraySize = s"$arr.numElements() + 1"
val newArray = ctx.freshName("newArray")
val i = ctx.freshName("i")
val iPlus1 = s"$i+1"
val zero = "0"
val allocation = CodeGenerator.createArrayData(
newArray,
elementType,
newArraySize,
s" $prettyName failed.")
val assignment =
CodeGenerator.createArrayAssignment(newArray, elementType, arr, iPlus1, i, false)
val newElemAssignment =
CodeGenerator.setArrayElement(newArray, elementType, zero, value, Some(rightGen.isNull))
s"""
|$allocation
|$newElemAssignment
|for (int $i = 0; $i < $arr.numElements(); $i++) {
| $assignment
|}
|${ev.value} = $newArray;
|""".stripMargin
}
val resultCode = f(leftGen.value, rightGen.value)
if (nullable) {
val nullSafeEval = leftGen.code + rightGen.code + ctx.nullSafeExec(nullable, leftGen.isNull) {
s"""
|${ev.isNull} = false;
|${resultCode}
|""".stripMargin
}
ev.copy(code =
code"""
|boolean ${ev.isNull} = true;
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$nullSafeEval
""".stripMargin
)
} else {
ev.copy(code =
code"""
|${leftGen.code}
|${rightGen.code}
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$resultCode
""".stripMargin, isNull = FalseLiteral)
}
}

override def prettyName: String = "array_prepend"

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): ArrayPrepend =
copy(left = newLeft, right = newRight)

override def dataType: DataType = if (right.nullable) left.dataType.asNullable else left.dataType

override def checkInputDataTypes(): TypeCheckResult = {
Contributor: I think we can refer to ArrayRemove#checkInputDataTypes here.

Contributor (author): I did what ArrayContains does. Maybe we should consolidate this, since it makes sense for many of these to do the same thing? eg. ArrayContains, ArrayRemove, ArrayPrepend, ArrayAppend etc.
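A minimal sketch of what such a consolidated check could look like; the trait name ArrayElementTypeCheck and the method name checkArrayElementTypes are hypothetical and not part of this PR or of Spark.

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.types.ArrayType

// Hypothetical mixin (not part of this PR): a shared input-type check for binary
// array expressions whose left child is an array and whose right child is an element.
trait ArrayElementTypeCheck extends QueryErrorsBase { self: BinaryExpression =>
  protected def checkArrayElementTypes(): TypeCheckResult = {
    (left.dataType, right.dataType) match {
      case (ArrayType(e1, _), e2) if e1.sameType(e2) =>
        TypeCheckResult.TypeCheckSuccess
      case (ArrayType(_, _), _) =>
        // The array's element type and the provided element's type do not match.
        DataTypeMismatch(
          errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
          messageParameters = Map(
            "functionName" -> toSQLId(prettyName),
            "leftType" -> toSQLType(left.dataType),
            "rightType" -> toSQLType(right.dataType),
            "dataType" -> toSQLType(ArrayType)))
      case _ =>
        // The first argument is not an array at all.
        DataTypeMismatch(
          errorSubClass = "UNEXPECTED_INPUT_TYPE",
          messageParameters = Map(
            "paramIndex" -> "0",
            "requiredType" -> toSQLType(ArrayType),
            "inputSql" -> toSQLExpr(left),
            "inputType" -> toSQLType(left.dataType)))
    }
  }
}

ArrayContains, ArrayRemove, ArrayPrepend, and ArrayAppend could then implement checkInputDataTypes by delegating to checkArrayElementTypes(), keeping their error messages consistent.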

(left.dataType, right.dataType) match {
case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess
case (ArrayType(e1, _), e2) => DataTypeMismatch(
errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
messageParameters = Map(
"functionName" -> toSQLId(prettyName),
"leftType" -> toSQLType(left.dataType),
"rightType" -> toSQLType(right.dataType),
"dataType" -> toSQLType(ArrayType)
))
case _ =>
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> "0",
"requiredType" -> toSQLType(ArrayType),
"inputSql" -> toSQLExpr(left),
"inputType" -> toSQLType(left.dataType)
)
)
}
}
override def inputTypes: Seq[AbstractDataType] = {
(left.dataType, right.dataType) match {
case (ArrayType(e1, hasNull), e2) =>
TypeCoercion.findTightestCommonType(e1, e2) match {
case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
case _ => Seq.empty
}
case _ => Seq.empty
}
}
}

/**
* Checks if the two arrays contain at least one common element.
*/
@@ -1855,6 +1855,50 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null)
}

test("SPARK-41233: ArrayPrepend") {
val a0 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType))
val a1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType))
val a2 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
val a3 = Literal.create(null, ArrayType(StringType))

checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4))
checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c"))
checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1))
checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), Seq(null))
checkEvaluation(ArrayPrepend(a3, Literal("a")), null)
checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null)

// complex data types
val data = Seq[Array[Byte]](
Array[Byte](5, 6),
Array[Byte](1, 2),
Array[Byte](1, 2),
Array[Byte](5, 6))
val b0 = Literal.create(
data,
ArrayType(BinaryType))
val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType))
val nullBinary = Literal.create(null, BinaryType)
// Calling ArrayPrepend with a null element should result in NULL being prepended to the array
val dataWithNullPrepended = null +: data
checkEvaluation(ArrayPrepend(b0, nullBinary), dataWithNullPrepended)
val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType)
checkEvaluation(
ArrayPrepend(b1, dataToPrepend1),
Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](2, 1), null))

val c0 = Literal.create(
Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
ArrayType(ArrayType(IntegerType)))
val dataToPrepend2 = Literal.create(Seq[Int](5, 6), ArrayType(IntegerType))
checkEvaluation(
ArrayPrepend(c0, dataToPrepend2),
Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4)))
checkEvaluation(
ArrayPrepend(c0, Literal.create(Seq.empty[Int], ArrayType(IntegerType))),
Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4)))
}

test("Array remove") {
val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType))
val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType))
11 changes: 11 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4044,6 +4044,17 @@ object functions {
ArrayCompact(column.expr)
}

/**
* Returns an array containing value as well as all elements from array. The new element is
* positioned at the beginning of the array.
*
* @group collection_funcs
* @since 3.5.0
*/
def array_prepend(column: Column, element: Any): Column = withExpr {
ArrayPrepend(column.expr, lit(element).expr)
}

/**
* Removes duplicate values from the array.
* @group collection_funcs
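For reference, a minimal usage sketch of the new Scala API on the DataFrame side; the object name, session setup, and sample data are assumptions for illustration, not part of this PR.

// Illustrative sketch only: assumes a Spark build containing this change.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array_prepend, col}

object ArrayPrependDataFrameDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("array_prepend-df-demo")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(Seq(2, 3, 4), Seq.empty[Int]).toDF("data")
    // Prepends the literal 1 to every array in the column; a NULL array would stay NULL.
    df.select(array_prepend(col("data"), 1).as("prepended"))
      .show(false) // expected rows: [1, 2, 3, 4] and [1]
    spark.stop()
  }
}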
@@ -26,6 +26,7 @@
| org.apache.spark.sql.catalyst.expressions.ArrayMax | array_max | SELECT array_max(array(1, 20, null, 3)) | struct<array_max(array(1, 20, NULL, 3)):int> |
| org.apache.spark.sql.catalyst.expressions.ArrayMin | array_min | SELECT array_min(array(1, 20, null, 3)) | struct<array_min(array(1, 20, NULL, 3)):int> |
| org.apache.spark.sql.catalyst.expressions.ArrayPosition | array_position | SELECT array_position(array(3, 2, 1), 1) | struct<array_position(array(3, 2, 1), 1):bigint> |
| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct<array_prepend(array(b, d, c, a), d):array<string>> |
| org.apache.spark.sql.catalyst.expressions.ArrayRemove | array_remove | SELECT array_remove(array(1, 2, 3, null, 3), 3) | struct<array_remove(array(1, 2, 3, NULL, 3), 3):array<int>> |
| org.apache.spark.sql.catalyst.expressions.ArrayRepeat | array_repeat | SELECT array_repeat('123', 2) | struct<array_repeat(123, 2):array<string>> |
| org.apache.spark.sql.catalyst.expressions.ArraySize | array_size | SELECT array_size(array('b', 'd', 'c', 'a')) | struct<array_size(array(b, d, c, a)):int> |
@@ -421,4 +422,4 @@
| org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>','a/b/text()') | struct<xpath(<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>, a/b/text()):array<string>> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_long(<a><b>1</b><b>2</b></a>, sum(a/b)):bigint> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_short(<a><b>1</b><b>2</b></a>, sum(a/b)):smallint> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |
11 changes: 11 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/array.sql
@@ -160,3 +160,14 @@ select array_append(CAST(null AS ARRAY<String>), CAST(null as String));
select array_append(array(), 1);
select array_append(CAST(array() AS ARRAY<String>), CAST(NULL AS String));
select array_append(array(CAST(NULL AS String)), CAST(NULL AS String));

-- function array_prepend
select array_prepend(array(1, 2, 3), 4);
select array_prepend(array('a', 'b', 'c'), 'd');
select array_prepend(array(1, 2, 3, NULL), NULL);
select array_prepend(array('a', 'b', 'c', NULL), NULL);
select array_prepend(CAST(null AS ARRAY<String>), 'a');
select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String));
select array_prepend(array(), 1);
select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String));
select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String));
72 changes: 72 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
@@ -784,3 +784,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String))
struct<array_append(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>>
-- !query output
[null,null]


-- !query
select array_prepend(array(1, 2, 3), 4)
-- !query schema
struct<array_prepend(array(1, 2, 3), 4):array<int>>
-- !query output
[4,1,2,3]


-- !query
select array_prepend(array('a', 'b', 'c'), 'd')
-- !query schema
struct<array_prepend(array(a, b, c), d):array<string>>
-- !query output
["d","a","b","c"]


-- !query
select array_prepend(array(1, 2, 3, NULL), NULL)
-- !query schema
struct<array_prepend(array(1, 2, 3, NULL), NULL):array<int>>
-- !query output
[null,1,2,3,null]


-- !query
select array_prepend(array('a', 'b', 'c', NULL), NULL)
-- !query schema
struct<array_prepend(array(a, b, c, NULL), NULL):array<string>>
-- !query output
[null,"a","b","c",null]


-- !query
select array_prepend(CAST(null AS ARRAY<String>), 'a')
-- !query schema
struct<array_prepend(NULL, a):array<string>>
-- !query output
NULL


-- !query
select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String))
-- !query schema
struct<array_prepend(NULL, CAST(NULL AS STRING)):array<string>>
-- !query output
NULL


-- !query
select array_prepend(array(), 1)
-- !query schema
struct<array_prepend(array(), 1):array<int>>
-- !query output
[1]


-- !query
select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String))
-- !query schema
struct<array_prepend(array(), CAST(NULL AS STRING)):array<string>>
-- !query output
[null]


-- !query
select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String))
-- !query schema
struct<array_prepend(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>>
-- !query output
[null,null]