
Commit 1e7e47d

navinvishy authored and zhengruifeng committed
[SPARK-41233][SQL][PYTHON] Add array_prepend function
### What changes were proposed in this pull request?

Adds a new array function, array_prepend, to catalyst.

### Why are the changes needed?

This adds a function that exists in many SQL implementations, specifically Snowflake: https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/api/snowflake.snowpark.functions.array_prepend.html

### Does this PR introduce _any_ user-facing change?

Yes.

### How was this patch tested?

Added unit tests.

Closes #38947 from navinvishy/array-prepend.

Lead-authored-by: Navin Viswanath <[email protected]>
Co-authored-by: navinvishy <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent d0c174c commit 1e7e47d
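
For orientation, here is a minimal usage sketch of the function this commit adds. It is illustrative only: it assumes a Spark build at 3.5.0 or later that contains this change, and the application/session setup below is not part of the commit.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array_prepend, col}

object ArrayPrependDemo {
  def main(args: Array[String]): Unit = {
    // Local session for illustration; any existing SparkSession works the same way.
    val spark = SparkSession.builder().master("local[*]").appName("array_prepend-demo").getOrCreate()
    import spark.implicits._

    // SQL form, matching the golden-file output added under sql-tests/results in this commit.
    spark.sql("SELECT array_prepend(array(1, 2, 3), 4)").show(false) // array contents: 4, 1, 2, 3

    // DataFrame form via the new org.apache.spark.sql.functions.array_prepend.
    val df = Seq(Seq(2, 3, 4), Seq.empty[Int]).toDF("data")
    df.select(array_prepend(col("data"), 1)).show(false) // contents: [1, 2, 3, 4] and [1]

    spark.stop()
  }
}
```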

12 files changed: +459 -1 lines changed


connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala

Lines changed: 1 addition & 0 deletions
@@ -177,6 +177,7 @@ object CheckConnectJvmClientCompatibility {
     ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.broadcast"),
     ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedlit"),
     ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedLit"),
+    ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.array_prepend"),
 
     // RelationalGroupedDataset
     ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"),

python/docs/source/reference/pyspark.sql/functions.rst

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@ Collection Functions
     array_sort
     array_insert
     array_remove
+    array_prepend
     array_distinct
     array_intersect
     array_union

python/pyspark/sql/functions.py

Lines changed: 30 additions & 0 deletions
@@ -7631,6 +7631,36 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
     return _invoke_function_over_columns("get", col, index)
 
 
+@try_remote_functions
+def array_prepend(col: "ColumnOrName", value: Any) -> Column:
+    """
+    Collection function: Returns an array containing the given element as
+    well as all elements from the array. The new element is positioned
+    at the beginning of the array.
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+    value :
+        a literal value, or a :class:`~pyspark.sql.Column` expression.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array with the given value prepended.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data'])
+    >>> df.select(array_prepend(df.data, 1)).collect()
+    [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])]
+    """
+    return _invoke_function_over_columns("array_prepend", col, lit(value))
+
+
 @try_remote_functions
 def array_remove(col: "ColumnOrName", element: Any) -> Column:
     """

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
@@ -697,6 +697,7 @@ object FunctionRegistry {
     expression[Sequence]("sequence"),
     expression[ArrayRepeat]("array_repeat"),
     expression[ArrayRemove]("array_remove"),
+    expression[ArrayPrepend]("array_prepend"),
     expression[ArrayDistinct]("array_distinct"),
     expression[ArrayTransform]("transform"),
     expression[MapFilter]("map_filter"),
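
With the registry entry above, `array_prepend` resolves by name from SQL. A quick spark-shell check (a sketch: it assumes an active `spark` session on a build containing this commit):

```scala
// DESCRIBE FUNCTION surfaces the ExpressionDescription registered for ArrayPrepend.
spark.sql("DESCRIBE FUNCTION EXTENDED array_prepend").show(truncate = false)

// Matches the example in the expression's documentation and the golden files below.
spark.sql("SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd')").show(false) // ["d","b","d","c","a"]
```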

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 146 additions & 0 deletions
@@ -1399,6 +1399,152 @@ case class ArrayContains(left: Expression, right: Expression)
     copy(left = newLeft, right = newRight)
 }
 
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array, element) - Adds the element at the beginning of the array passed as the first
+      argument. The type of element should be the same as the type of the elements of the array.
+      A null element is also prepended to the array. But if the array passed is NULL,
+      the output is NULL.
+    """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
+       ["d","b","d","c","a"]
+      > SELECT _FUNC_(array(1, 2, 3, null), null);
+       [null,1,2,3,null]
+      > SELECT _FUNC_(CAST(null as Array<Int>), 2);
+       NULL
+  """,
+  group = "array_funcs",
+  since = "3.5.0")
+case class ArrayPrepend(left: Expression, right: Expression)
+  extends BinaryExpression
+    with ImplicitCastInputTypes
+    with ComplexTypeMergingExpression
+    with QueryErrorsBase {
+
+  override def nullable: Boolean = left.nullable
+
+  @transient protected lazy val elementType: DataType =
+    inputTypes.head.asInstanceOf[ArrayType].elementType
+
+  override def eval(input: InternalRow): Any = {
+    val value1 = left.eval(input)
+    if (value1 == null) {
+      null
+    } else {
+      val value2 = right.eval(input)
+      nullSafeEval(value1, value2)
+    }
+  }
+  override def nullSafeEval(arr: Any, elementData: Any): Any = {
+    val arrayData = arr.asInstanceOf[ArrayData]
+    val numberOfElements = arrayData.numElements() + 1
+    if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+      throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements)
+    }
+    val finalData = new Array[Any](numberOfElements)
+    finalData.update(0, elementData)
+    arrayData.foreach(elementType, (i: Int, v: Any) => finalData.update(i + 1, v))
+    new GenericArrayData(finalData)
+  }
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val leftGen = left.genCode(ctx)
+    val rightGen = right.genCode(ctx)
+    val f = (arr: String, value: String) => {
+      val newArraySize = s"$arr.numElements() + 1"
+      val newArray = ctx.freshName("newArray")
+      val i = ctx.freshName("i")
+      val iPlus1 = s"$i+1"
+      val zero = "0"
+      val allocation = CodeGenerator.createArrayData(
+        newArray,
+        elementType,
+        newArraySize,
+        s" $prettyName failed.")
+      val assignment =
+        CodeGenerator.createArrayAssignment(newArray, elementType, arr, iPlus1, i, false)
+      val newElemAssignment =
+        CodeGenerator.setArrayElement(newArray, elementType, zero, value, Some(rightGen.isNull))
+      s"""
+         |$allocation
+         |$newElemAssignment
+         |for (int $i = 0; $i < $arr.numElements(); $i ++) {
+         |  $assignment
+         |}
+         |${ev.value} = $newArray;
+         |""".stripMargin
+    }
+    val resultCode = f(leftGen.value, rightGen.value)
+    if (nullable) {
+      val nullSafeEval = leftGen.code + rightGen.code + ctx.nullSafeExec(nullable, leftGen.isNull) {
+        s"""
+           |${ev.isNull} = false;
+           |${resultCode}
+           |""".stripMargin
+      }
+      ev.copy(code =
+        code"""
+          |boolean ${ev.isNull} = true;
+          |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+          |$nullSafeEval
+        """.stripMargin
+      )
+    } else {
+      ev.copy(code =
+        code"""
+          |${leftGen.code}
+          |${rightGen.code}
+          |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+          |$resultCode
+        """.stripMargin, isNull = FalseLiteral)
+    }
+  }
+
+  override def prettyName: String = "array_prepend"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ArrayPrepend =
+    copy(left = newLeft, right = newRight)
+
+  override def dataType: DataType = if (right.nullable) left.dataType.asNullable else left.dataType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess
+      case (ArrayType(e1, _), e2) => DataTypeMismatch(
+        errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "leftType" -> toSQLType(left.dataType),
+          "rightType" -> toSQLType(right.dataType),
+          "dataType" -> toSQLType(ArrayType)
+        ))
+      case _ =>
+        DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> "0",
+            "requiredType" -> toSQLType(ArrayType),
+            "inputSql" -> toSQLExpr(left),
+            "inputType" -> toSQLType(left.dataType)
+          )
+        )
    }
+  }
+  override def inputTypes: Seq[AbstractDataType] = {
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(e1, e2) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case _ => Seq.empty
+    }
+  }
+}
+
 /**
  * Checks if the two arrays contain at least one common element.
  */
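
Not part of the diff: a simplified plain-Scala model of the null semantics described in the expression's usage string above, i.e. a NULL input array yields NULL while a NULL element is still prepended. The real expression operates on `ArrayData`, has a codegen path, and enforces the array-size limit.

```scala
// Simplified behavioral model only; Spark's ArrayPrepend works on ArrayData, not Seq.
def arrayPrependModel[T](arr: Seq[T], elem: T): Seq[T] =
  if (arr == null) null else elem +: arr

arrayPrependModel(Seq(1, 2, 3), 4)     // Seq(4, 1, 2, 3)
arrayPrependModel(Seq("a", "b"), null) // Seq(null, "a", "b") -- a null element is still prepended
arrayPrependModel(null: Seq[Int], 2)   // null                -- a NULL array gives NULL
```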

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 44 additions & 0 deletions
@@ -1855,6 +1855,50 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null)
   }
 
+  test("SPARK-41233: ArrayPrepend") {
+    val a0 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType))
+    val a2 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
+    val a3 = Literal.create(null, ArrayType(StringType))
+
+    checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4))
+    checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c"))
+    checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1))
+    checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), Seq(null))
+    checkEvaluation(ArrayPrepend(a3, Literal("a")), null)
+    checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null)
+
+    // complex data types
+    val data = Seq[Array[Byte]](
+      Array[Byte](5, 6),
+      Array[Byte](1, 2),
+      Array[Byte](1, 2),
+      Array[Byte](5, 6))
+    val b0 = Literal.create(
+      data,
+      ArrayType(BinaryType))
+    val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType))
+    val nullBinary = Literal.create(null, BinaryType)
+    // Calling ArrayPrepend with a null element should result in NULL being prepended to the array
+    val dataWithNullPrepended = null +: data
+    checkEvaluation(ArrayPrepend(b0, nullBinary), dataWithNullPrepended)
+    val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType)
+    checkEvaluation(
+      ArrayPrepend(b1, dataToPrepend1),
+      Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](2, 1), null))
+
+    val c0 = Literal.create(
+      Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
+      ArrayType(ArrayType(IntegerType)))
+    val dataToPrepend2 = Literal.create(Seq[Int](5, 6), ArrayType(IntegerType))
+    checkEvaluation(
+      ArrayPrepend(c0, dataToPrepend2),
+      Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4)))
+    checkEvaluation(
+      ArrayPrepend(c0, Literal.create(Seq.empty[Int], ArrayType(IntegerType))),
+      Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4)))
+  }
+
   test("Array remove") {
     val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType))
     val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType))

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 11 additions & 0 deletions
@@ -4044,6 +4044,17 @@ object functions {
     ArrayCompact(column.expr)
   }
 
+  /**
+   * Returns an array containing value as well as all elements from array. The new element is
+   * positioned at the beginning of the array.
+   *
+   * @group collection_funcs
+   * @since 3.5.0
+   */
+  def array_prepend(column: Column, element: Any): Column = withExpr {
+    ArrayPrepend(column.expr, lit(element).expr)
+  }
+
   /**
    * Removes duplicate values from the array.
    * @group collection_funcs
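
A short usage note on the Scala API above (a sketch: the `spark` session, column names, and data are illustrative, not part of the commit). Because `element` is typed `Any` and routed through `lit`, both plain literals and `Column` expressions can be prepended:

```scala
// spark-shell style; assumes an active SparkSession named `spark`.
import spark.implicits._
import org.apache.spark.sql.functions.{array_prepend, col}

val events = Seq(
  (Seq("a", "b"), "urgent"),
  (Seq.empty[String], "info")
).toDF("tags", "primaryTag") // hypothetical column names

// Prepend a literal element.
events.select(array_prepend(col("tags"), "x")).show(false)

// Prepend another column's value; lit() passes an existing Column through unchanged.
events.select(array_prepend(col("tags"), col("primaryTag"))).show(false)
```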

sql/core/src/test/resources/sql-functions/sql-expression-schema.md

Lines changed: 2 additions & 1 deletion
@@ -26,6 +26,7 @@
 | org.apache.spark.sql.catalyst.expressions.ArrayMax | array_max | SELECT array_max(array(1, 20, null, 3)) | struct<array_max(array(1, 20, NULL, 3)):int> |
 | org.apache.spark.sql.catalyst.expressions.ArrayMin | array_min | SELECT array_min(array(1, 20, null, 3)) | struct<array_min(array(1, 20, NULL, 3)):int> |
 | org.apache.spark.sql.catalyst.expressions.ArrayPosition | array_position | SELECT array_position(array(3, 2, 1), 1) | struct<array_position(array(3, 2, 1), 1):bigint> |
+| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct<array_prepend(array(b, d, c, a), d):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayRemove | array_remove | SELECT array_remove(array(1, 2, 3, null, 3), 3) | struct<array_remove(array(1, 2, 3, NULL, 3), 3):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayRepeat | array_repeat | SELECT array_repeat('123', 2) | struct<array_repeat(123, 2):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.ArraySize | array_size | SELECT array_size(array('b', 'd', 'c', 'a')) | struct<array_size(array(b, d, c, a)):int> |
@@ -421,4 +422,4 @@
 | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>','a/b/text()') | struct<xpath(<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>, a/b/text()):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_long(<a><b>1</b><b>2</b></a>, sum(a/b)):bigint> |
 | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_short(<a><b>1</b><b>2</b></a>, sum(a/b)):smallint> |
-| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |
+| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |

sql/core/src/test/resources/sql-tests/inputs/array.sql

Lines changed: 11 additions & 0 deletions
@@ -160,3 +160,14 @@ select array_append(CAST(null AS ARRAY<String>), CAST(null as String));
 select array_append(array(), 1);
 select array_append(CAST(array() AS ARRAY<String>), CAST(NULL AS String));
 select array_append(array(CAST(NULL AS String)), CAST(NULL AS String));
+
+-- function array_prepend
+select array_prepend(array(1, 2, 3), 4);
+select array_prepend(array('a', 'b', 'c'), 'd');
+select array_prepend(array(1, 2, 3, NULL), NULL);
+select array_prepend(array('a', 'b', 'c', NULL), NULL);
+select array_prepend(CAST(null AS ARRAY<String>), 'a');
+select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String));
+select array_prepend(array(), 1);
+select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String));
+select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String));

sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out

Lines changed: 72 additions & 0 deletions
@@ -784,3 +784,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String))
 struct<array_append(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>>
 -- !query output
 [null,null]
+
+
+-- !query
+select array_prepend(array(1, 2, 3), 4)
+-- !query schema
+struct<array_prepend(array(1, 2, 3), 4):array<int>>
+-- !query output
+[4,1,2,3]
+
+
+-- !query
+select array_prepend(array('a', 'b', 'c'), 'd')
+-- !query schema
+struct<array_prepend(array(a, b, c), d):array<string>>
+-- !query output
+["d","a","b","c"]
+
+
+-- !query
+select array_prepend(array(1, 2, 3, NULL), NULL)
+-- !query schema
+struct<array_prepend(array(1, 2, 3, NULL), NULL):array<int>>
+-- !query output
+[null,1,2,3,null]
+
+
+-- !query
+select array_prepend(array('a', 'b', 'c', NULL), NULL)
+-- !query schema
+struct<array_prepend(array(a, b, c, NULL), NULL):array<string>>
+-- !query output
+[null,"a","b","c",null]
+
+
+-- !query
+select array_prepend(CAST(null AS ARRAY<String>), 'a')
+-- !query schema
+struct<array_prepend(NULL, a):array<string>>
+-- !query output
+NULL
+
+
+-- !query
+select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String))
+-- !query schema
+struct<array_prepend(NULL, CAST(NULL AS STRING)):array<string>>
+-- !query output
+NULL
+
+
+-- !query
+select array_prepend(array(), 1)
+-- !query schema
+struct<array_prepend(array(), 1):array<int>>
+-- !query output
+[1]
+
+
+-- !query
+select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String))
+-- !query schema
+struct<array_prepend(array(), CAST(NULL AS STRING)):array<string>>
+-- !query output
+[null]
+
+
+-- !query
+select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String))
+-- !query schema
+struct<array_prepend(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>>
+-- !query output
+[null,null]
