From 78f960bfb5465b1ca85a90deb18447f223156ae1 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:34:49 -0800 Subject: [PATCH 01/68] Adds a array_prepend expression to catalyst --- .../reference/pyspark.sql/functions.rst | 1 + python/pyspark/sql/functions.py | 27 ++++- .../catalyst/analysis/FunctionRegistry.scala | 12 ++ .../expressions/collectionOperations.scala | 113 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 41 +++++++ .../org/apache/spark/sql/functions.scala | 10 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 52 ++++++++ 7 files changed, 255 insertions(+), 1 deletion(-) diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index ddc8eab90f77..13a64b721542 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -158,6 +158,7 @@ Collection Functions array_append array_sort array_remove + array_prepend array_distinct array_intersect array_union diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3426f2bdaf6c..edd9f24b92ec 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7618,6 +7618,32 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +def array_prepend(col: "ColumnOrName", element: Any) -> Column: + """ + Collection function: Returns an array containing value as well as all elements from array. + The new element is positioned at the beginning of the array. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column containing array + element : + element to be prepended to the array + + Returns + ------- + :class:`~pyspark.sql.Column` + an array excluding given value. 
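+        The result is the input array with ``element`` placed at the first position.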
+ + Examples + -------- + >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + >>> df.select(array_prepend(df.data, 1)).collect() + [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] + """ + return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: @@ -7649,7 +7675,6 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) - @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 99bab1003767..a1750d6d45d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -694,6 +694,7 @@ object FunctionRegistry { expression[Sequence]("sequence"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), + expression[ArrayPrepend]("array_prepend"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), expression[MapFilter]("map_filter"), @@ -967,6 +968,7 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { +<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -976,6 +978,16 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) +======= + try { + builder(expressions) + } catch { + case e: AnalysisException => + val argTypes = expressions.map(_.dataType.typeName).mkString(", ") + throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( + name, argTypes, info.getUsage, e.getMessage) + } +>>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ca3982f54c8b..ff0bce168787 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1399,6 +1399,119 @@ case class ArrayContains(left: Expression, right: Expression) copy(left = newLeft, right = newRight) } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = + "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. 
The new element is positioned at the beginning of the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 4); + [4, 1, 2, 3] + """, + group = "array_funcs", + since = "3.4.0") +case class ArrayPrepend(left: Expression, right: Expression) + extends BinaryExpression + with ImplicitCastInputTypes + with NullIntolerant + with QueryErrorsBase { + + override def nullSafeEval(arr: Any, value: Any): Any = { + val numberOfElements = arr.asInstanceOf[ArrayData].numElements() + if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) + } + val newArray = new Array[Any](numberOfElements + 1) + newArray(0) = value + var pos = 1 + arr + .asInstanceOf[ArrayData] + .foreach( + right.dataType, + (i, v) => { + newArray(pos) = v + pos += 1 + }) + new GenericArrayData(newArray) + } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen( + ctx, + ev, + (arr, value) => { + val newArraySize = ctx.freshName("newArraySize") + val newArray = ctx.freshName("newArray") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val allocation = CodeGenerator.createArrayData( + newArray, + right.dataType, + newArraySize, + s" $prettyName failed.") + val assignment = + CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) + val newElemAssignment = + CodeGenerator.setArrayElement(newArray, right.dataType, pos, value) + s""" + |int $pos = 0; + |int $newArraySize = $arr.numElements() + 1; + |$allocation + |$newElemAssignment + |$pos = $pos + 1; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $assignment + | $pos = $pos + 1; + |} + |${ev.value} = $newArray; + |""".stripMargin + }) + } + + override def prettyName: String = "array_prepend" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayPrepend = + copy(left = newLeft, right = newRight) + override def dataType: DataType = left.dataType + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (_, NullType) | (NullType, _) => + DataTypeMismatch( + errorSubClass = "NULL_TYPE", + messageParameters = Map("functionName" -> toSQLId(prettyName))) + case (l, _) if !ArrayType.acceptsType(l) => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "1", + "requiredType" -> toSQLType(ArrayType), + "inputSql" -> toSQLExpr(left), + "inputType" -> toSQLType(left.dataType))) + case (ArrayType(e1, _), e2) if e1.sameType(e2) => + TypeUtils.checkForOrderingExpr(e2, prettyName) + case _ => + DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "dataType" -> toSQLType(ArrayType), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType))) + } + } + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (_, NullType) => Seq.empty + case (ArrayType(e1, hasNull), e2) => + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull), dt) + case _ => Seq.empty + } + case _ => Seq.empty + } + } +} + /** * Checks if the two arrays contain at least one common element. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d83739df38d1..a4c3ffeeb0d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1840,6 +1840,47 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) } + test("SPARK-41233: ArrayPrepend") { + val a0 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType)) + val a1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val a2 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4)) + checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c")) + checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1)) + checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), null) + checkEvaluation(ArrayPrepend(a3, Literal("a")), null) + checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null) + + // complex data types + val b0 = Literal.create( + Seq[Array[Byte]]( + Array[Byte](5, 6), + Array[Byte](1, 2), + Array[Byte](1, 2), + Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType)) + val nullBinary = Literal.create(null, BinaryType) + checkEvaluation(ArrayPrepend(b0, nullBinary), null) + val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType) + checkEvaluation( + ArrayPrepend(b1, dataToPrepend1), + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](2, 1), null)) + + val c0 = Literal.create( + Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val dataToPrepend2 = Literal.create(Seq[Int](5, 6), ArrayType(IntegerType)) + checkEvaluation( + ArrayPrepend(c0, dataToPrepend2), + Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation( + ArrayPrepend(c0, Literal.create(Seq.empty[Int], ArrayType(IntegerType))), + Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4))) + } + test("Array remove") { val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType)) val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3d5547ead831..be0c99aeec71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,6 +4042,16 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) + + /** + * Returns an array containing value as well as all elements from array. The new element is + * positioned at the beginning of the array. 
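+   * For example, `array_prepend(lit(Array(1, 2)), 0)` yields an array of (0, 1, 2).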
+ * + * @group collection_funcs + * @since 3.4.0 + */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 231c9562511a..d8b54dbc76ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2651,6 +2651,58 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) } + test("SPARK-41233: array prepend") { + val df = Seq( + (Array[Int](2, 3, 4), Array("b", "c", "d"), Array("", ""), 2), + (Array.empty[Int], Array.empty[String], Array.empty[String], 2), + (null, null, null, 2)).toDF("a", "b", "c", "d") + checkAnswer( + df.select(array_prepend($"a", 1), array_prepend($"b", "a"), array_prepend($"c", "")), + Seq( + Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")), + Row(Seq(1), Seq("a"), Seq("")), + Row(null, null, null))) + checkAnswer( + df.select(array_prepend($"a", $"d")), + Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + checkAnswer( + df.selectExpr("array_prepend(a, d)"), + Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + checkAnswer( + OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"), + Seq(Row(Seq(1.23, 1.0, 2.0)))) + checkAnswer( + df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"), + Seq( + Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")), + Row(Seq(1), Seq("a"), Seq("")), + Row(null, null, null))) + checkError( + exception = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_prepend(_1, _2)") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "paramIndex" -> "1", + "sqlExpr" -> "\"array_prepend(_1, _2)\"", + "inputSql" -> "\"_1\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"ARRAY\""), + queryContext = Array(ExpectedContext("", "", 0, 20, "array_prepend(_1, _2)"))) + checkError( + exception = intercept[AnalysisException] { + OneRowRelation().selectExpr("array_prepend(array(1, 2), '1')") + }, + errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", + parameters = Map( + "sqlExpr" -> "\"array_prepend(array(1, 2), 1)\"", + "functionName" -> "`array_prepend`", + "dataType" -> "\"ARRAY\"", + "leftType" -> "\"ARRAY\"", + "rightType" -> "\"STRING\""), + queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')"))) + } + test("array remove") { val df = Seq( (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2), From f2d4f68ab2db1ba7c1c896f4474d63caab585297 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:36:47 -0800 Subject: [PATCH 02/68] Fix null handling --- .../expressions/collectionOperations.scala | 122 +++++++++++------- .../CollectionExpressionsSuite.scala | 19 +-- .../org/apache/spark/sql/functions.scala | 15 +++ 3 files changed, 101 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ff0bce168787..000e61bb6a21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1413,9 +1413,19 @@ case class ArrayContains(left: Expression, right: Expression) case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes - with NullIntolerant with QueryErrorsBase { + override def nullable: Boolean = left.nullable + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + nullSafeEval(value1, value2) + } + } override def nullSafeEval(arr: Any, value: Any): Any = { val numberOfElements = arr.asInstanceOf[ArrayData].numElements() if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { @@ -1435,36 +1445,57 @@ case class ArrayPrepend(left: Expression, right: Expression) new GenericArrayData(newArray) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen( - ctx, - ev, - (arr, value) => { - val newArraySize = ctx.freshName("newArraySize") - val newArray = ctx.freshName("newArray") - val i = ctx.freshName("i") - val pos = ctx.freshName("pos") - val allocation = CodeGenerator.createArrayData( - newArray, - right.dataType, - newArraySize, - s" $prettyName failed.") - val assignment = - CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) - val newElemAssignment = - CodeGenerator.setArrayElement(newArray, right.dataType, pos, value) + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val f = (arr: String, value: String) => { + val newArraySize = ctx.freshName("newArraySize") + val newArray = ctx.freshName("newArray") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val allocation = CodeGenerator.createArrayData( + newArray, + right.dataType, + newArraySize, + s" $prettyName failed.") + val assignment = + CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) + val newElemAssignment = + CodeGenerator.setArrayElement(newArray, right.dataType, pos, value, Some(rightGen.isNull)) + s""" + |int $pos = 0; + |int $newArraySize = $arr.numElements() + 1; + |$allocation + |$newElemAssignment + |$pos = $pos + 1; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $assignment + | $pos = $pos + 1; + |} + |${ev.value} = $newArray; + |""".stripMargin + } + val resultCode = f(leftGen.value, rightGen.value) + if(nullable) { + val nullSafeEval = leftGen.code + rightGen.code + ctx.nullSafeExec(nullable, leftGen.isNull) { s""" - |int $pos = 0; - |int $newArraySize = $arr.numElements() + 1; - |$allocation - |$newElemAssignment - |$pos = $pos + 1; - |for (int $i = 0; $i < $arr.numElements(); $i ++) { - | $assignment - | $pos = $pos + 1; - |} - |${ev.value} = $newArray; + |${ev.isNull} = false; + |${resultCode} |""".stripMargin - }) + } + ev.copy(code = + code""" + boolean ${ev.isNull} = true; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $nullSafeEval + """) + } else { + ev.copy(code = + code""" + ${leftGen.code} + ${rightGen.code} + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $resultCode""", isNull = FalseLiteral) + } } override def prettyName: String = "array_prepend" @@ -1472,31 +1503,30 @@ case class ArrayPrepend(left: Expression, right: Expression) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ArrayPrepend = copy(left = newLeft, right = 
newRight) + override def dataType: DataType = left.dataType + override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { - case (_, NullType) | (NullType, _) => - DataTypeMismatch( - errorSubClass = "NULL_TYPE", - messageParameters = Map("functionName" -> toSQLId(prettyName))) - case (l, _) if !ArrayType.acceptsType(l) => + case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess + case (ArrayType(e1, _), e2) => DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType), + "dataType" -> toSQLType(ArrayType) + )) + case _ => DataTypeMismatch( errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( - "paramIndex" -> "1", + "paramIndex" -> "0", "requiredType" -> toSQLType(ArrayType), "inputSql" -> toSQLExpr(left), - "inputType" -> toSQLType(left.dataType))) - case (ArrayType(e1, _), e2) if e1.sameType(e2) => - TypeUtils.checkForOrderingExpr(e2, prettyName) - case _ => - DataTypeMismatch( - errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", - messageParameters = Map( - "functionName" -> toSQLId(prettyName), - "dataType" -> toSQLType(ArrayType), - "leftType" -> toSQLType(left.dataType), - "rightType" -> toSQLType(right.dataType))) + "inputType" -> toSQLType(left.dataType) + ) + ) } } override def inputTypes: Seq[AbstractDataType] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a4c3ffeeb0d7..b443a1c7f5aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,10 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone - import scala.language.implicitConversions import scala.util.Random - import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -1849,21 +1847,24 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4)) checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c")) checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1)) - checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), null) + checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), Seq(null)) checkEvaluation(ArrayPrepend(a3, Literal("a")), null) checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null) // complex data types + val data = Seq[Array[Byte]]( + Array[Byte](5, 6), + Array[Byte](1, 2), + Array[Byte](1, 2), + Array[Byte](5, 6)) val b0 = Literal.create( - Seq[Array[Byte]]( - Array[Byte](5, 6), - Array[Byte](1, 2), - Array[Byte](1, 2), - Array[Byte](5, 6)), + data, ArrayType(BinaryType)) val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType)) val nullBinary = Literal.create(null, BinaryType) - checkEvaluation(ArrayPrepend(b0, nullBinary), null) + // Calling ArrayPrepend with a null element should result in NULL being prepended to the array + val dataWithNullPrepended = null +: data + 
checkEvaluation(ArrayPrepend(b0, nullBinary), dataWithNullPrepended) val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType) checkEvaluation( ArrayPrepend(b1, dataToPrepend1), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index be0c99aeec71..23f78e532ec0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4050,6 +4050,21 @@ object functions { * @group collection_funcs * @since 3.4.0 */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) + + /** + * Remove all null elements from the given array. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_compact(column: Column): Column = withExpr { + ArrayCompact(column.expr) + /** + * Returns an array containing value as well as all elements from array.The + * new element is positioned at the beginning of the array. + */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From 85a8b4ce9b9bed366df5de073eb8d8f2c0c37b99 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 03/68] Fix --- .../scala/org/apache/spark/sql/functions.scala | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 23f78e532ec0..79e42db22776 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,7 +4042,7 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) - + } /** * Returns an array containing value as well as all elements from array. The new element is * positioned at the beginning of the array. @@ -4050,21 +4050,6 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def array_prepend(column: Column, element: Any): Column = withExpr { - ArrayPrepend(column.expr, lit(element).expr) - - /** - * Remove all null elements from the given array. - * - * @group collection_funcs - * @since 3.4.0 - */ - def array_compact(column: Column): Column = withExpr { - ArrayCompact(column.expr) - /** - * Returns an array containing value as well as all elements from array.The - * new element is positioned at the beginning of the array. 
- */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From d7b601cedb9c8eb945e196d0e168d68f64bf0af0 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 04/68] Fix --- .../catalyst/analysis/FunctionRegistry.scala | 11 ----------- .../scala/org/apache/spark/sql/functions.scala | 17 +---------------- 2 files changed, 1 insertion(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a1750d6d45d8..d9765a20a80b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -968,7 +968,6 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { -<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -978,16 +977,6 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) -======= - try { - builder(expressions) - } catch { - case e: AnalysisException => - val argTypes = expressions.map(_.dataType.typeName).mkString(", ") - throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( - name, argTypes, info.getUsage, e.getMessage) - } ->>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 23f78e532ec0..79e42db22776 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,7 +4042,7 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) - + } /** * Returns an array containing value as well as all elements from array. The new element is * positioned at the beginning of the array. @@ -4050,21 +4050,6 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def array_prepend(column: Column, element: Any): Column = withExpr { - ArrayPrepend(column.expr, lit(element).expr) - - /** - * Remove all null elements from the given array. - * - * @group collection_funcs - * @since 3.4.0 - */ - def array_compact(column: Column): Column = withExpr { - ArrayCompact(column.expr) - /** - * Returns an array containing value as well as all elements from array.The - * new element is positioned at the beginning of the array. 
- */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From d827d21e33e2c9dc99fa653c8550ad6e28e96330 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:45:20 -0800 Subject: [PATCH 05/68] Lint --- .../expressions/CollectionExpressionsSuite.scala | 1 + .../spark/sql/DataFrameFunctionsSuite.scala | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index b443a1c7f5aa..fc769b3f1773 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone + import scala.language.implicitConversions import scala.util.Random import org.apache.spark.{SparkFunSuite, SparkRuntimeException} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d8b54dbc76ee..a4641ee646d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2664,13 +2664,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(null, null, null))) checkAnswer( df.select(array_prepend($"a", $"d")), - Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + Seq( + Row(Seq(2, 2, 3, 4)), + Row(Seq(2)), + Row(null))) checkAnswer( df.selectExpr("array_prepend(a, d)"), - Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + Seq( + Row(Seq(2, 2, 3, 4)), + Row(Seq(2)), + Row(null))) checkAnswer( OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"), - Seq(Row(Seq(1.23, 1.0, 2.0)))) + Seq( + Row(Seq(1.23, 1.0, 2.0)) + ) + ) checkAnswer( df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"), Seq( From 4e7aa1e4504b014f7b0a380a97faa4a61b6f34ce Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:47:55 -0800 Subject: [PATCH 06/68] Lint --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index fc769b3f1773..f94d216beed6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -23,6 +23,7 @@ import java.util.TimeZone import scala.language.implicitConversions import scala.util.Random + import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow From ff8c19e4cf40641fe072d00118b393065f8e5416 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Wed, 8 Feb 2023 20:46:44 -0800 Subject: [PATCH 07/68] Add examples of usage and fix test --- python/pyspark/sql/functions.py | 6 +- 
.../sql-functions/sql-expression-schema.md | 3 +- .../test/resources/sql-tests/inputs/array.sql | 11 +++ .../sql-tests/results/ansi/array.sql.out | 72 +++++++++++++++++++ 4 files changed, 89 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 572465ff8346..548b0266d4ef 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7618,10 +7618,12 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +@try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Returns an array containing value as well as all elements from array. - The new element is positioned at the beginning of the array. + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned + at the beginning of the array. .. versionadded:: 3.4.0 diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 03ec4bce54b4..cf355e11fc4e 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,6 +13,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | +| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | @@ -420,4 +421,4 @@ | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('b1b2b3c1c2','a/b/text()') | structb1b2b3c1c2, a/b/text()):array> | | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('12', 'sum(a/b)') | struct12, sum(a/b)):bigint> | | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('12', 'sum(a/b)') | struct12, sum(a/b)):smallint> | -| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | \ No newline at end of file +| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql index 3d107cb6dfc0..d3c36b79d1f3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/array.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -160,3 +160,14 @@ select array_append(CAST(null AS ARRAY), CAST(null as String)); select array_append(array(), 1); select array_append(CAST(array() AS ARRAY), CAST(NULL AS String)); select 
array_append(array(CAST(NULL AS String)), CAST(NULL AS String)); + +-- function array_prepend +select array_prepend(array(1, 2, 3), 4); +select array_prepend(array('a', 'b', 'c'), 'd'); +select array_prepend(array(1, 2, 3, NULL), NULL); +select array_prepend(array('a', 'b', 'c', NULL), NULL); +select array_prepend(CAST(null AS ARRAY), 'a'); +select array_prepend(CAST(null AS ARRAY), CAST(null as String)); +select array_prepend(array(), 1); +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)); +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out index 0d8ef39ed60c..d228c605705d 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out @@ -784,3 +784,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String)) struct> -- !query output [null,null] + + +-- !query +select array_prepend(array(1, 2, 3), 4) +-- !query schema +struct> +-- !query output +[4,1,2,3] + + +-- !query +select array_prepend(array('a', 'b', 'c'), 'd') +-- !query schema +struct> +-- !query output +["d","a","b","c"] + + +-- !query +select array_prepend(array(1, 2, 3, NULL), NULL) +-- !query schema +struct> +-- !query output +[null,1,2,3,null] + + +-- !query +select array_prepend(array('a', 'b', 'c', NULL), NULL) +-- !query schema +struct> +-- !query output +[null,"a","b","c",null] + + +-- !query +select array_prepend(CAST(null AS ARRAY), 'a') +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(CAST(null AS ARRAY), CAST(null as String)) +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(array(), 1) +-- !query schema +struct> +-- !query output +[1] + + +-- !query +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null] + + +-- !query +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null,null] From 4f3a9685303bcaf57dcdeb4ac017b25c357064e6 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Thu, 9 Feb 2023 20:44:09 -0800 Subject: [PATCH 08/68] Fix tests --- python/pyspark/sql/functions.py | 9 +-- .../expressions/collectionOperations.scala | 3 +- .../sql-functions/sql-expression-schema.md | 2 +- .../resources/sql-tests/results/array.sql.out | 72 +++++++++++++++++++ 4 files changed, 80 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 548b0266d4ef..c8a709d27c7c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7621,10 +7621,10 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Returns an array containing element as - well as all elements from array. The new element is positioned + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned at the beginning of the array. - + .. versionadded:: 3.4.0 Parameters @@ -7636,6 +7636,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Returns ------- + :class:`~pyspark.sql.Column` an array excluding given value. 
@@ -7644,7 +7645,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b122a629585b..6e2beda4bccd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1413,6 +1413,7 @@ case class ArrayContains(left: Expression, right: Expression) case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes + with ComplexTypeMergingExpression with QueryErrorsBase { override def nullable: Boolean = left.nullable @@ -1533,7 +1534,7 @@ case class ArrayPrepend(left: Expression, right: Expression) (left.dataType, right.dataType) match { case (_, NullType) => Seq.empty case (ArrayType(e1, hasNull), e2) => - TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match { + TypeCoercion.findTightestCommonType(e1, e2) match { case Some(dt) => Seq(ArrayType(dt, hasNull), dt) case _ => Seq.empty } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index cf355e11fc4e..6146b7fcb9c0 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,7 +13,6 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | -| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | @@ -27,6 +26,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayMax | array_max | SELECT array_max(array(1, 20, null, 3)) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayMin | array_min | SELECT array_min(array(1, 20, null, 3)) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayPosition | array_position | SELECT array_position(array(3, 2, 1), 1) | struct | +| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayRemove | array_remove | SELECT array_remove(array(1, 2, 3, null, 3), 3) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayRepeat | 
array_repeat | SELECT array_repeat('123', 2) | struct> | | org.apache.spark.sql.catalyst.expressions.ArraySize | array_size | SELECT array_size(array('b', 'd', 'c', 'a')) | struct | diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out index 609122a23d31..029bd767f54c 100644 --- a/sql/core/src/test/resources/sql-tests/results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -665,3 +665,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String)) struct> -- !query output [null,null] + + +-- !query +select array_prepend(array(1, 2, 3), 4) +-- !query schema +struct> +-- !query output +[4,1,2,3] + + +-- !query +select array_prepend(array('a', 'b', 'c'), 'd') +-- !query schema +struct> +-- !query output +["d","a","b","c"] + + +-- !query +select array_prepend(array(1, 2, 3, NULL), NULL) +-- !query schema +struct> +-- !query output +[null,1,2,3,null] + + +-- !query +select array_prepend(array('a', 'b', 'c', NULL), NULL) +-- !query schema +struct> +-- !query output +[null,"a","b","c",null] + + +-- !query +select array_prepend(CAST(null AS ARRAY), 'a') +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(CAST(null AS ARRAY), CAST(null as String)) +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(array(), 1) +-- !query schema +struct> +-- !query output +[1] + + +-- !query +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null] + + +-- !query +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null,null] From ec9ea76e9cfa7bf78bc277f447fdd6c9cb95f3e4 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 10 Feb 2023 22:03:34 -0800 Subject: [PATCH 09/68] Fix types --- python/pyspark/sql/functions.py | 2 +- .../spark/sql/catalyst/expressions/collectionOperations.scala | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index c8a709d27c7c..f9230f5478eb 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7624,7 +7624,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Collection function: Returns an array containing element as well as all elements from array. The new element is positioned at the beginning of the array. - + .. 
versionadded:: 3.4.0 Parameters diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6e2beda4bccd..73be6327bca1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1532,7 +1532,6 @@ case class ArrayPrepend(left: Expression, right: Expression) } override def inputTypes: Seq[AbstractDataType] = { (left.dataType, right.dataType) match { - case (_, NullType) => Seq.empty case (ArrayType(e1, hasNull), e2) => TypeCoercion.findTightestCommonType(e1, e2) match { case Some(dt) => Seq(ArrayType(dt, hasNull), dt) From c728581bfad8e36b29d6bf74fc40a8ea8a3f5c6c Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 16:46:34 -0800 Subject: [PATCH 10/68] Fix tests --- python/pyspark/sql/functions.py | 1 - .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++-- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f9230f5478eb..294ec3669a98 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7636,7 +7636,6 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Returns ------- - :class:`~pyspark.sql.Column` an array excluding given value. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 73be6327bca1..068d18f1727b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1405,8 +1405,8 @@ case class ArrayContains(left: Expression, right: Expression) "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. 
The new element is positioned at the beginning of the array.", examples = """ Examples: - > SELECT _FUNC_(array(1, 2, 3), 4); - [4, 1, 2, 3] + > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd'); + ["d","b","d","c","a"] """, group = "array_funcs", since = "3.4.0") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e34290396212..4fd350d8db26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2692,7 +2692,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { }, errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( - "paramIndex" -> "1", + "paramIndex" -> "0", "sqlExpr" -> "\"array_prepend(_1, _2)\"", "inputSql" -> "\"_1\"", "inputType" -> "\"STRING\"", From 6eba188c53abe689df84af448f0626679ed73708 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 19:18:02 -0800 Subject: [PATCH 11/68] Fix python linter --- python/pyspark/sql/functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 294ec3669a98..915470b06ca3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7618,6 +7618,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) + @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ @@ -7644,9 +7645,10 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) + @try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: """ @@ -7677,6 +7679,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) + @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ From 63ff6cc605719bcbde4b9f30bd484b7c9e3ed575 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 20:30:18 -0800 Subject: [PATCH 12/68] Add test for null cases --- .../expressions/collectionOperations.scala | 28 ++++++++----------- .../spark/sql/DataFrameFunctionsSuite.scala | 7 +++++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 068d18f1727b..0a4680193e01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1418,6 +1418,9 @@ case class ArrayPrepend(left: Expression, right: Expression) override def nullable: Boolean = left.nullable + @transient protected lazy val elementType: DataType = + inputTypes.head.asInstanceOf[ArrayType].elementType + override def eval(input: InternalRow): Any = { val value1 = left.eval(input) if (value1 == null) { @@ -1427,23 +1430,16 @@ case class ArrayPrepend(left: 
Expression, right: Expression) nullSafeEval(value1, value2) } } - override def nullSafeEval(arr: Any, value: Any): Any = { - val numberOfElements = arr.asInstanceOf[ArrayData].numElements() - if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + override def nullSafeEval(arr: Any, elementData: Any): Any = { + val arrayData = arr.asInstanceOf[ArrayData] + val numberOfElements = arrayData.numElements() + 1 + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) } - val newArray = new Array[Any](numberOfElements + 1) - newArray(0) = value - var pos = 1 - arr - .asInstanceOf[ArrayData] - .foreach( - right.dataType, - (i, v) => { - newArray(pos) = v - pos += 1 - }) - new GenericArrayData(newArray) + val finalData = new Array[Any](numberOfElements) + finalData.update(0, elementData) + arrayData.foreach(elementType, (i: Int, v: Any) => finalData.update(i + 1, v)) + new GenericArrayData(finalData) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val leftGen = left.genCode(ctx) @@ -1505,7 +1501,7 @@ case class ArrayPrepend(left: Expression, right: Expression) newLeft: Expression, newRight: Expression): ArrayPrepend = copy(left = newLeft, right = newRight) - override def dataType: DataType = left.dataType + override def dataType: DataType = if (right.nullable) left.dataType.asNullable else left.dataType override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4fd350d8db26..bc096f923fa2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2710,6 +2710,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "leftType" -> "\"ARRAY\"", "rightType" -> "\"STRING\""), queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')"))) + val df2 = Seq((Array[String]("a", "b", "c"), "d"), + (null, "d"), + (Array[String]("x", "y", "z"), null), + (null, null) + ).toDF("a", "b") + checkAnswer(df2.selectExpr("array_prepend(a, b)"), + Seq(Row(Seq("d", "a", "b", "c")), Row(null), Row(Seq(null, "x", "y", "z")), Row(null))) } test("array remove") { From 73a7dd78550b20321a0c6313c3b6e651848ae176 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 20:35:43 -0800 Subject: [PATCH 13/68] Fix type of array --- .../sql/catalyst/expressions/collectionOperations.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0a4680193e01..d27f3d3f7851 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1451,13 +1451,13 @@ case class ArrayPrepend(left: Expression, right: Expression) val pos = ctx.freshName("pos") val allocation = CodeGenerator.createArrayData( newArray, - right.dataType, + elementType, newArraySize, s" $prettyName failed.") val assignment = - CodeGenerator.createArrayAssignment(newArray, 
right.dataType, arr, pos, i, false)
+      CodeGenerator.createArrayAssignment(newArray, elementType, arr, pos, i, false)
     val newElemAssignment =
-      CodeGenerator.setArrayElement(newArray, right.dataType, pos, value, Some(rightGen.isNull))
+      CodeGenerator.setArrayElement(newArray, elementType, pos, value, Some(rightGen.isNull))

From 6f97761ed7596c66f943737bee22cacd5200fd95 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Fri, 27 Jan 2023 20:34:49 -0800
Subject: [PATCH 14/68] Adds an array_prepend expression to catalyst

---
 .../reference/pyspark.sql/functions.rst       |   1 +
 python/pyspark/sql/functions.py               |  27 ++++-
 .../catalyst/analysis/FunctionRegistry.scala  |  12 ++
 .../expressions/collectionOperations.scala    | 113 ++++++++++++++++++
 .../CollectionExpressionsSuite.scala          |  41 +++++++
 .../org/apache/spark/sql/functions.scala      |  10 ++
 .../spark/sql/DataFrameFunctionsSuite.scala   |  52 ++++++++
 7 files changed, 255 insertions(+), 1 deletion(-)

diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index 70fc04ef9cf2..cbc46e1fae18 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -159,6 +159,7 @@ Collection Functions
     array_sort
     array_insert
     array_remove
+    array_prepend
     array_distinct
     array_intersect
     array_union

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8bee517de6af..572465ff8346 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7618,6 +7618,32 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
     return _invoke_function_over_columns("get", col, index)
 
 
+def array_prepend(col: "ColumnOrName", element: Any) -> Column:
+    """
+    Collection function: Returns an array containing value as well as all elements from array.
+    The new element is positioned at the beginning of the array.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+    element :
+        element to be prepended to the array
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array with the given element prepended.
+ + Examples + -------- + >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + >>> df.select(array_prepend(df.data, 1)).collect() + [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] + """ + return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: @@ -7649,7 +7675,6 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) - @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d87cc0126cfa..ce9e58722a2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -696,6 +696,7 @@ object FunctionRegistry { expression[Sequence]("sequence"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), + expression[ArrayPrepend]("array_prepend"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), expression[MapFilter]("map_filter"), @@ -969,6 +970,7 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { +<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -978,6 +980,16 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) +======= + try { + builder(expressions) + } catch { + case e: AnalysisException => + val argTypes = expressions.map(_.dataType.typeName).mkString(", ") + throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( + name, argTypes, info.getUsage, e.getMessage) + } +>>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 92a3127d438a..226b8fcdddd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1399,6 +1399,119 @@ case class ArrayContains(left: Expression, right: Expression) copy(left = newLeft, right = newRight) } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = + "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. 
The new element is positioned at the beginning of the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 4); + [4, 1, 2, 3] + """, + group = "array_funcs", + since = "3.4.0") +case class ArrayPrepend(left: Expression, right: Expression) + extends BinaryExpression + with ImplicitCastInputTypes + with NullIntolerant + with QueryErrorsBase { + + override def nullSafeEval(arr: Any, value: Any): Any = { + val numberOfElements = arr.asInstanceOf[ArrayData].numElements() + if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) + } + val newArray = new Array[Any](numberOfElements + 1) + newArray(0) = value + var pos = 1 + arr + .asInstanceOf[ArrayData] + .foreach( + right.dataType, + (i, v) => { + newArray(pos) = v + pos += 1 + }) + new GenericArrayData(newArray) + } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen( + ctx, + ev, + (arr, value) => { + val newArraySize = ctx.freshName("newArraySize") + val newArray = ctx.freshName("newArray") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val allocation = CodeGenerator.createArrayData( + newArray, + right.dataType, + newArraySize, + s" $prettyName failed.") + val assignment = + CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) + val newElemAssignment = + CodeGenerator.setArrayElement(newArray, right.dataType, pos, value) + s""" + |int $pos = 0; + |int $newArraySize = $arr.numElements() + 1; + |$allocation + |$newElemAssignment + |$pos = $pos + 1; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $assignment + | $pos = $pos + 1; + |} + |${ev.value} = $newArray; + |""".stripMargin + }) + } + + override def prettyName: String = "array_prepend" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayPrepend = + copy(left = newLeft, right = newRight) + override def dataType: DataType = left.dataType + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (_, NullType) | (NullType, _) => + DataTypeMismatch( + errorSubClass = "NULL_TYPE", + messageParameters = Map("functionName" -> toSQLId(prettyName))) + case (l, _) if !ArrayType.acceptsType(l) => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "1", + "requiredType" -> toSQLType(ArrayType), + "inputSql" -> toSQLExpr(left), + "inputType" -> toSQLType(left.dataType))) + case (ArrayType(e1, _), e2) if e1.sameType(e2) => + TypeUtils.checkForOrderingExpr(e2, prettyName) + case _ => + DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "dataType" -> toSQLType(ArrayType), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType))) + } + } + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (_, NullType) => Seq.empty + case (ArrayType(e1, hasNull), e2) => + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull), dt) + case _ => Seq.empty + } + case _ => Seq.empty + } + } +} + /** * Checks if the two arrays contain at least one common element. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 9b97430594d0..56472a553af2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1840,6 +1840,47 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) } + test("SPARK-41233: ArrayPrepend") { + val a0 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType)) + val a1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val a2 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4)) + checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c")) + checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1)) + checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), null) + checkEvaluation(ArrayPrepend(a3, Literal("a")), null) + checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null) + + // complex data types + val b0 = Literal.create( + Seq[Array[Byte]]( + Array[Byte](5, 6), + Array[Byte](1, 2), + Array[Byte](1, 2), + Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType)) + val nullBinary = Literal.create(null, BinaryType) + checkEvaluation(ArrayPrepend(b0, nullBinary), null) + val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType) + checkEvaluation( + ArrayPrepend(b1, dataToPrepend1), + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](2, 1), null)) + + val c0 = Literal.create( + Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val dataToPrepend2 = Literal.create(Seq[Int](5, 6), ArrayType(IntegerType)) + checkEvaluation( + ArrayPrepend(c0, dataToPrepend2), + Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation( + ArrayPrepend(c0, Literal.create(Seq.empty[Int], ArrayType(IntegerType))), + Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4))) + } + test("Array remove") { val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType)) val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cb5c1ad5c495..d2f1df8780b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,6 +4042,16 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) + + /** + * Returns an array containing value as well as all elements from array. The new element is + * positioned at the beginning of the array. 
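+ * For example, prepending 1 to an array column holding [2, 3, 4] yields [1, 2, 3, 4].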
+ * + * @group collection_funcs + * @since 3.4.0 + */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6ed8299976c0..f31278bae006 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2651,6 +2651,58 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) } + test("SPARK-41233: array prepend") { + val df = Seq( + (Array[Int](2, 3, 4), Array("b", "c", "d"), Array("", ""), 2), + (Array.empty[Int], Array.empty[String], Array.empty[String], 2), + (null, null, null, 2)).toDF("a", "b", "c", "d") + checkAnswer( + df.select(array_prepend($"a", 1), array_prepend($"b", "a"), array_prepend($"c", "")), + Seq( + Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")), + Row(Seq(1), Seq("a"), Seq("")), + Row(null, null, null))) + checkAnswer( + df.select(array_prepend($"a", $"d")), + Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + checkAnswer( + df.selectExpr("array_prepend(a, d)"), + Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + checkAnswer( + OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"), + Seq(Row(Seq(1.23, 1.0, 2.0)))) + checkAnswer( + df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"), + Seq( + Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")), + Row(Seq(1), Seq("a"), Seq("")), + Row(null, null, null))) + checkError( + exception = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_prepend(_1, _2)") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "paramIndex" -> "1", + "sqlExpr" -> "\"array_prepend(_1, _2)\"", + "inputSql" -> "\"_1\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"ARRAY\""), + queryContext = Array(ExpectedContext("", "", 0, 20, "array_prepend(_1, _2)"))) + checkError( + exception = intercept[AnalysisException] { + OneRowRelation().selectExpr("array_prepend(array(1, 2), '1')") + }, + errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", + parameters = Map( + "sqlExpr" -> "\"array_prepend(array(1, 2), 1)\"", + "functionName" -> "`array_prepend`", + "dataType" -> "\"ARRAY\"", + "leftType" -> "\"ARRAY\"", + "rightType" -> "\"STRING\""), + queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')"))) + } + test("array remove") { val df = Seq( (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2), From fec08e9fd1ae7aa2ce08e6ae4a054fadfb1143ad Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:36:47 -0800 Subject: [PATCH 15/68] Fix null handling --- .../expressions/collectionOperations.scala | 122 +++++++++++------- .../CollectionExpressionsSuite.scala | 19 +-- .../org/apache/spark/sql/functions.scala | 15 +++ 3 files changed, 101 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 226b8fcdddd6..b122a629585b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1413,9 +1413,19 @@ case class ArrayContains(left: Expression, right: Expression) case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes - with NullIntolerant with QueryErrorsBase { + override def nullable: Boolean = left.nullable + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + nullSafeEval(value1, value2) + } + } override def nullSafeEval(arr: Any, value: Any): Any = { val numberOfElements = arr.asInstanceOf[ArrayData].numElements() if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { @@ -1435,36 +1445,57 @@ case class ArrayPrepend(left: Expression, right: Expression) new GenericArrayData(newArray) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen( - ctx, - ev, - (arr, value) => { - val newArraySize = ctx.freshName("newArraySize") - val newArray = ctx.freshName("newArray") - val i = ctx.freshName("i") - val pos = ctx.freshName("pos") - val allocation = CodeGenerator.createArrayData( - newArray, - right.dataType, - newArraySize, - s" $prettyName failed.") - val assignment = - CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) - val newElemAssignment = - CodeGenerator.setArrayElement(newArray, right.dataType, pos, value) + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val f = (arr: String, value: String) => { + val newArraySize = ctx.freshName("newArraySize") + val newArray = ctx.freshName("newArray") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val allocation = CodeGenerator.createArrayData( + newArray, + right.dataType, + newArraySize, + s" $prettyName failed.") + val assignment = + CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) + val newElemAssignment = + CodeGenerator.setArrayElement(newArray, right.dataType, pos, value, Some(rightGen.isNull)) + s""" + |int $pos = 0; + |int $newArraySize = $arr.numElements() + 1; + |$allocation + |$newElemAssignment + |$pos = $pos + 1; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $assignment + | $pos = $pos + 1; + |} + |${ev.value} = $newArray; + |""".stripMargin + } + val resultCode = f(leftGen.value, rightGen.value) + if(nullable) { + val nullSafeEval = leftGen.code + rightGen.code + ctx.nullSafeExec(nullable, leftGen.isNull) { s""" - |int $pos = 0; - |int $newArraySize = $arr.numElements() + 1; - |$allocation - |$newElemAssignment - |$pos = $pos + 1; - |for (int $i = 0; $i < $arr.numElements(); $i ++) { - | $assignment - | $pos = $pos + 1; - |} - |${ev.value} = $newArray; + |${ev.isNull} = false; + |${resultCode} |""".stripMargin - }) + } + ev.copy(code = + code""" + boolean ${ev.isNull} = true; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $nullSafeEval + """) + } else { + ev.copy(code = + code""" + ${leftGen.code} + ${rightGen.code} + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $resultCode""", isNull = FalseLiteral) + } } override def prettyName: String = "array_prepend" @@ -1472,31 +1503,30 @@ case class ArrayPrepend(left: Expression, right: Expression) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ArrayPrepend = copy(left = newLeft, right = 
newRight) + override def dataType: DataType = left.dataType + override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { - case (_, NullType) | (NullType, _) => - DataTypeMismatch( - errorSubClass = "NULL_TYPE", - messageParameters = Map("functionName" -> toSQLId(prettyName))) - case (l, _) if !ArrayType.acceptsType(l) => + case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess + case (ArrayType(e1, _), e2) => DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType), + "dataType" -> toSQLType(ArrayType) + )) + case _ => DataTypeMismatch( errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( - "paramIndex" -> "1", + "paramIndex" -> "0", "requiredType" -> toSQLType(ArrayType), "inputSql" -> toSQLExpr(left), - "inputType" -> toSQLType(left.dataType))) - case (ArrayType(e1, _), e2) if e1.sameType(e2) => - TypeUtils.checkForOrderingExpr(e2, prettyName) - case _ => - DataTypeMismatch( - errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", - messageParameters = Map( - "functionName" -> toSQLId(prettyName), - "dataType" -> toSQLType(ArrayType), - "leftType" -> toSQLType(left.dataType), - "rightType" -> toSQLType(right.dataType))) + "inputType" -> toSQLType(left.dataType) + ) + ) } } override def inputTypes: Seq[AbstractDataType] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 56472a553af2..dc8cc44a6535 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,10 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone - import scala.language.implicitConversions import scala.util.Random - import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -1849,21 +1847,24 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4)) checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c")) checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1)) - checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), null) + checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), Seq(null)) checkEvaluation(ArrayPrepend(a3, Literal("a")), null) checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null) // complex data types + val data = Seq[Array[Byte]]( + Array[Byte](5, 6), + Array[Byte](1, 2), + Array[Byte](1, 2), + Array[Byte](5, 6)) val b0 = Literal.create( - Seq[Array[Byte]]( - Array[Byte](5, 6), - Array[Byte](1, 2), - Array[Byte](1, 2), - Array[Byte](5, 6)), + data, ArrayType(BinaryType)) val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType)) val nullBinary = Literal.create(null, BinaryType) - checkEvaluation(ArrayPrepend(b0, nullBinary), null) + // Calling ArrayPrepend with a null element should result in NULL being prepended to the array + val dataWithNullPrepended = null +: data + 
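// i.e. Seq(null, Array(5, 6), Array(1, 2), Array(1, 2), Array(5, 6))
+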
checkEvaluation(ArrayPrepend(b0, nullBinary), dataWithNullPrepended) val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType) checkEvaluation( ArrayPrepend(b1, dataToPrepend1), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d2f1df8780b3..1f66a5daa2d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4050,6 +4050,21 @@ object functions { * @group collection_funcs * @since 3.4.0 */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) + + /** + * Remove all null elements from the given array. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_compact(column: Column): Column = withExpr { + ArrayCompact(column.expr) + /** + * Returns an array containing value as well as all elements from array.The + * new element is positioned at the beginning of the array. + */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From a8da3455a17bf22cb1b6695c671117599521faa7 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 16/68] Fix --- .../scala/org/apache/spark/sql/functions.scala | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 1f66a5daa2d2..069e7b79fcb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,7 +4042,7 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) - + } /** * Returns an array containing value as well as all elements from array. The new element is * positioned at the beginning of the array. @@ -4050,21 +4050,6 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def array_prepend(column: Column, element: Any): Column = withExpr { - ArrayPrepend(column.expr, lit(element).expr) - - /** - * Remove all null elements from the given array. - * - * @group collection_funcs - * @since 3.4.0 - */ - def array_compact(column: Column): Column = withExpr { - ArrayCompact(column.expr) - /** - * Returns an array containing value as well as all elements from array.The - * new element is positioned at the beginning of the array. 
- */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From 3af8fd96c156fe8c25686cedf250b55a7a07ef90 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 17/68] Fix --- .../sql/catalyst/analysis/FunctionRegistry.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ce9e58722a2c..472396bbef22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -970,7 +970,6 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { -<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -980,16 +979,6 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) -======= - try { - builder(expressions) - } catch { - case e: AnalysisException => - val argTypes = expressions.map(_.dataType.typeName).mkString(", ") - throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( - name, argTypes, info.getUsage, e.getMessage) - } ->>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } From 2cd3e180e24239075b8469877924f2d8040e31e1 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:45:20 -0800 Subject: [PATCH 18/68] Lint --- .../expressions/CollectionExpressionsSuite.scala | 1 + .../spark/sql/DataFrameFunctionsSuite.scala | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index dc8cc44a6535..9ace0cbf854b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone + import scala.language.implicitConversions import scala.util.Random import org.apache.spark.{SparkFunSuite, SparkRuntimeException} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index f31278bae006..e34290396212 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2664,13 +2664,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(null, null, null))) checkAnswer( df.select(array_prepend($"a", $"d")), - Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + Seq( + Row(Seq(2, 2, 3, 4)), + Row(Seq(2)), + Row(null))) checkAnswer( df.selectExpr("array_prepend(a, d)"), - Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + Seq( + Row(Seq(2, 2, 3, 4)), + 
Row(Seq(2)), + Row(null))) checkAnswer( OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"), - Seq(Row(Seq(1.23, 1.0, 2.0)))) + Seq( + Row(Seq(1.23, 1.0, 2.0)) + ) + ) checkAnswer( df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"), Seq( From a307cb401887ac43fd78083ff035179552ab7e32 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:47:55 -0800 Subject: [PATCH 19/68] Lint --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 9ace0cbf854b..667f717cce77 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -23,6 +23,7 @@ import java.util.TimeZone import scala.language.implicitConversions import scala.util.Random + import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow From 7b2450011791c8037f320ae6c1341d7f52ecff44 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Wed, 8 Feb 2023 20:46:44 -0800 Subject: [PATCH 20/68] Add examples of usage and fix test --- python/pyspark/sql/functions.py | 6 +- .../sql-functions/sql-expression-schema.md | 3 +- .../test/resources/sql-tests/inputs/array.sql | 11 +++ .../sql-tests/results/ansi/array.sql.out | 72 +++++++++++++++++++ 4 files changed, 89 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 572465ff8346..548b0266d4ef 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7618,10 +7618,12 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +@try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Returns an array containing value as well as all elements from array. - The new element is positioned at the beginning of the array. + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned + at the beginning of the array. .. 
versionadded:: 3.4.0 diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 03ec4bce54b4..cf355e11fc4e 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,6 +13,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | +| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | @@ -420,4 +421,4 @@ | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('b1b2b3c1c2','a/b/text()') | structb1b2b3c1c2, a/b/text()):array> | | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('12', 'sum(a/b)') | struct12, sum(a/b)):bigint> | | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('12', 'sum(a/b)') | struct12, sum(a/b)):smallint> | -| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | \ No newline at end of file +| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql index 3d107cb6dfc0..d3c36b79d1f3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/array.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -160,3 +160,14 @@ select array_append(CAST(null AS ARRAY), CAST(null as String)); select array_append(array(), 1); select array_append(CAST(array() AS ARRAY), CAST(NULL AS String)); select array_append(array(CAST(NULL AS String)), CAST(NULL AS String)); + +-- function array_prepend +select array_prepend(array(1, 2, 3), 4); +select array_prepend(array('a', 'b', 'c'), 'd'); +select array_prepend(array(1, 2, 3, NULL), NULL); +select array_prepend(array('a', 'b', 'c', NULL), NULL); +select array_prepend(CAST(null AS ARRAY), 'a'); +select array_prepend(CAST(null AS ARRAY), CAST(null as String)); +select array_prepend(array(), 1); +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)); +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out index 0d8ef39ed60c..d228c605705d 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out @@ -784,3 +784,75 @@ select array_append(array(CAST(NULL AS String)), 
CAST(NULL AS String)) struct> -- !query output [null,null] + + +-- !query +select array_prepend(array(1, 2, 3), 4) +-- !query schema +struct> +-- !query output +[4,1,2,3] + + +-- !query +select array_prepend(array('a', 'b', 'c'), 'd') +-- !query schema +struct> +-- !query output +["d","a","b","c"] + + +-- !query +select array_prepend(array(1, 2, 3, NULL), NULL) +-- !query schema +struct> +-- !query output +[null,1,2,3,null] + + +-- !query +select array_prepend(array('a', 'b', 'c', NULL), NULL) +-- !query schema +struct> +-- !query output +[null,"a","b","c",null] + + +-- !query +select array_prepend(CAST(null AS ARRAY), 'a') +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(CAST(null AS ARRAY), CAST(null as String)) +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(array(), 1) +-- !query schema +struct> +-- !query output +[1] + + +-- !query +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null] + + +-- !query +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null,null] From f29996292e022dcabd8ca0a306e82901b5476fd4 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Thu, 9 Feb 2023 20:44:09 -0800 Subject: [PATCH 21/68] Fix tests --- python/pyspark/sql/functions.py | 9 +-- .../expressions/collectionOperations.scala | 3 +- .../sql-functions/sql-expression-schema.md | 2 +- .../resources/sql-tests/results/array.sql.out | 72 +++++++++++++++++++ 4 files changed, 80 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 548b0266d4ef..c8a709d27c7c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7621,10 +7621,10 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Returns an array containing element as - well as all elements from array. The new element is positioned + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned at the beginning of the array. - + .. versionadded:: 3.4.0 Parameters @@ -7636,6 +7636,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Returns ------- + :class:`~pyspark.sql.Column` an array excluding given value. 
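+ (the value is in fact prepended to the array; see Examples below).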
@@ -7644,7 +7645,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b122a629585b..6e2beda4bccd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1413,6 +1413,7 @@ case class ArrayContains(left: Expression, right: Expression) case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes + with ComplexTypeMergingExpression with QueryErrorsBase { override def nullable: Boolean = left.nullable @@ -1533,7 +1534,7 @@ case class ArrayPrepend(left: Expression, right: Expression) (left.dataType, right.dataType) match { case (_, NullType) => Seq.empty case (ArrayType(e1, hasNull), e2) => - TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match { + TypeCoercion.findTightestCommonType(e1, e2) match { case Some(dt) => Seq(ArrayType(dt, hasNull), dt) case _ => Seq.empty } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index cf355e11fc4e..6146b7fcb9c0 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,7 +13,6 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | -| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | @@ -27,6 +26,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayMax | array_max | SELECT array_max(array(1, 20, null, 3)) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayMin | array_min | SELECT array_min(array(1, 20, null, 3)) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayPosition | array_position | SELECT array_position(array(3, 2, 1), 1) | struct | +| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayRemove | array_remove | SELECT array_remove(array(1, 2, 3, null, 3), 3) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayRepeat | 
array_repeat | SELECT array_repeat('123', 2) | struct> | | org.apache.spark.sql.catalyst.expressions.ArraySize | array_size | SELECT array_size(array('b', 'd', 'c', 'a')) | struct | diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out index 609122a23d31..029bd767f54c 100644 --- a/sql/core/src/test/resources/sql-tests/results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -665,3 +665,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String)) struct> -- !query output [null,null] + + +-- !query +select array_prepend(array(1, 2, 3), 4) +-- !query schema +struct> +-- !query output +[4,1,2,3] + + +-- !query +select array_prepend(array('a', 'b', 'c'), 'd') +-- !query schema +struct> +-- !query output +["d","a","b","c"] + + +-- !query +select array_prepend(array(1, 2, 3, NULL), NULL) +-- !query schema +struct> +-- !query output +[null,1,2,3,null] + + +-- !query +select array_prepend(array('a', 'b', 'c', NULL), NULL) +-- !query schema +struct> +-- !query output +[null,"a","b","c",null] + + +-- !query +select array_prepend(CAST(null AS ARRAY), 'a') +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(CAST(null AS ARRAY), CAST(null as String)) +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(array(), 1) +-- !query schema +struct> +-- !query output +[1] + + +-- !query +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null] + + +-- !query +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null,null] From 7ce00b8bd561007bb00868f45748d9372ec35699 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 10 Feb 2023 22:03:34 -0800 Subject: [PATCH 22/68] Fix types --- python/pyspark/sql/functions.py | 2 +- .../spark/sql/catalyst/expressions/collectionOperations.scala | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index c8a709d27c7c..f9230f5478eb 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7624,7 +7624,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Collection function: Returns an array containing element as well as all elements from array. The new element is positioned at the beginning of the array. - + .. 
versionadded:: 3.4.0 Parameters diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6e2beda4bccd..73be6327bca1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1532,7 +1532,6 @@ case class ArrayPrepend(left: Expression, right: Expression) } override def inputTypes: Seq[AbstractDataType] = { (left.dataType, right.dataType) match { - case (_, NullType) => Seq.empty case (ArrayType(e1, hasNull), e2) => TypeCoercion.findTightestCommonType(e1, e2) match { case Some(dt) => Seq(ArrayType(dt, hasNull), dt) From 505b8e23f51d5a89e4e623f3446b009e76c0a3c2 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 16:46:34 -0800 Subject: [PATCH 23/68] Fix tests --- python/pyspark/sql/functions.py | 1 - .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++-- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f9230f5478eb..294ec3669a98 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7636,7 +7636,6 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Returns ------- - :class:`~pyspark.sql.Column` an array excluding given value. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 73be6327bca1..068d18f1727b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1405,8 +1405,8 @@ case class ArrayContains(left: Expression, right: Expression) "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. 
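Returns NULL if the array itself is NULL.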
The new element is positioned at the beginning of the array.", examples = """ Examples: - > SELECT _FUNC_(array(1, 2, 3), 4); - [4, 1, 2, 3] + > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd'); + ["d","b","d","c","a"] """, group = "array_funcs", since = "3.4.0") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e34290396212..4fd350d8db26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2692,7 +2692,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { }, errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( - "paramIndex" -> "1", + "paramIndex" -> "0", "sqlExpr" -> "\"array_prepend(_1, _2)\"", "inputSql" -> "\"_1\"", "inputType" -> "\"STRING\"", From 216ca4c32477fe438420705c6f9fe0b197aaf4fc Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 19:18:02 -0800 Subject: [PATCH 24/68] Fix python linter --- python/pyspark/sql/functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 294ec3669a98..915470b06ca3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7618,6 +7618,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) + @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ @@ -7644,9 +7645,10 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) + @try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: """ @@ -7677,6 +7679,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) + @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ From c121168e60f821fd013c38c0af6897949ef16319 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 20:30:18 -0800 Subject: [PATCH 25/68] Add test for null cases --- .../expressions/collectionOperations.scala | 28 ++++++++----------- .../spark/sql/DataFrameFunctionsSuite.scala | 7 +++++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 068d18f1727b..0a4680193e01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1418,6 +1418,9 @@ case class ArrayPrepend(left: Expression, right: Expression) override def nullable: Boolean = left.nullable + @transient protected lazy val elementType: DataType = + inputTypes.head.asInstanceOf[ArrayType].elementType + override def eval(input: InternalRow): Any = { val value1 = left.eval(input) if (value1 == null) { @@ -1427,23 +1430,16 @@ case class ArrayPrepend(left: 
Expression, right: Expression) nullSafeEval(value1, value2) } } - override def nullSafeEval(arr: Any, value: Any): Any = { - val numberOfElements = arr.asInstanceOf[ArrayData].numElements() - if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + override def nullSafeEval(arr: Any, elementData: Any): Any = { + val arrayData = arr.asInstanceOf[ArrayData] + val numberOfElements = arrayData.numElements() + 1 + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) } - val newArray = new Array[Any](numberOfElements + 1) - newArray(0) = value - var pos = 1 - arr - .asInstanceOf[ArrayData] - .foreach( - right.dataType, - (i, v) => { - newArray(pos) = v - pos += 1 - }) - new GenericArrayData(newArray) + val finalData = new Array[Any](numberOfElements) + finalData.update(0, elementData) + arrayData.foreach(elementType, (i: Int, v: Any) => finalData.update(i + 1, v)) + new GenericArrayData(finalData) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val leftGen = left.genCode(ctx) @@ -1505,7 +1501,7 @@ case class ArrayPrepend(left: Expression, right: Expression) newLeft: Expression, newRight: Expression): ArrayPrepend = copy(left = newLeft, right = newRight) - override def dataType: DataType = left.dataType + override def dataType: DataType = if (right.nullable) left.dataType.asNullable else left.dataType override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4fd350d8db26..bc096f923fa2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2710,6 +2710,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "leftType" -> "\"ARRAY\"", "rightType" -> "\"STRING\""), queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')"))) + val df2 = Seq((Array[String]("a", "b", "c"), "d"), + (null, "d"), + (Array[String]("x", "y", "z"), null), + (null, null) + ).toDF("a", "b") + checkAnswer(df2.selectExpr("array_prepend(a, b)"), + Seq(Row(Seq("d", "a", "b", "c")), Row(null), Row(Seq(null, "x", "y", "z")), Row(null))) } test("array remove") { From ec503f9f7b718b02475eacb1fb6a7ce0ed52362c Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 20:35:43 -0800 Subject: [PATCH 26/68] Fix type of array --- .../sql/catalyst/expressions/collectionOperations.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0a4680193e01..d27f3d3f7851 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1451,13 +1451,13 @@ case class ArrayPrepend(left: Expression, right: Expression) val pos = ctx.freshName("pos") val allocation = CodeGenerator.createArrayData( newArray, - right.dataType, + elementType, newArraySize, s" $prettyName failed.") val assignment = - CodeGenerator.createArrayAssignment(newArray, 
right.dataType, arr, pos, i, false) + CodeGenerator.createArrayAssignment(newArray, elementType, arr, pos, i, false) val newElemAssignment = - CodeGenerator.setArrayElement(newArray, right.dataType, pos, value, Some(rightGen.isNull)) + CodeGenerator.setArrayElement(newArray, elementType, pos, value, Some(rightGen.isNull)) s""" |int $pos = 0; |int $newArraySize = $arr.numElements() + 1; From f1c01860c087cacc6025a9be9e2900502e29e5a4 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 28 Feb 2023 21:06:54 -0800 Subject: [PATCH 27/68] Address comments --- python/pyspark/sql/functions.py | 8 +++---- .../expressions/collectionOperations.scala | 22 ++++++++++++------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 915470b06ca3..12142b9e2ca0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7620,7 +7620,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: @try_remote_functions -def array_prepend(col: "ColumnOrName", element: Any) -> Column: +def array_prepend(col: "ColumnOrName", value: Any) -> Column: """ Collection function: Returns an array containing element as well as all elements from array. The new element is positioned @@ -7632,8 +7632,8 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: ---------- col : :class:`~pyspark.sql.Column` or str name of column containing array - element : - element to be prepended to the array + value : + a literal value, or a :class:`~pyspark.sql.Column` expression. Returns ------- @@ -7646,7 +7646,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] """ - return _invoke_function("array_prepend", _to_java_column(col), element) + return _invoke_function_over_columns("array_prepend", col, lit(value)) @try_remote_functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d27f3d3f7851..62b1c5afaa08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1401,12 +1401,20 @@ case class ArrayContains(left: Expression, right: Expression) // scalastyle:off line.size.limit @ExpressionDescription( - usage = - "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. The new element is positioned at the beginning of the array.", + usage = """ + _FUNC_(array, element) - Add the element at the beginning of the array passed as first + argument. Type of element should be similar to type of the elements of the array. + Null element is also prepended to the array. 
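(The result's element type becomes nullable in that case.)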
But if the array passed is NULL + output is NULL + """, examples = """ Examples: > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd'); ["d","b","d","c","a"] + > SELECT _FUNC_(array(1, 2, 3, null), null); + [null,1,2,3,null] + > SELECT _FUNC_(CAST(null as Array), 2); + NULL """, group = "array_funcs", since = "3.4.0") @@ -1448,25 +1456,23 @@ case class ArrayPrepend(left: Expression, right: Expression) val newArraySize = ctx.freshName("newArraySize") val newArray = ctx.freshName("newArray") val i = ctx.freshName("i") - val pos = ctx.freshName("pos") + val iPlus1 = s"$i+1" + val zero = "0" val allocation = CodeGenerator.createArrayData( newArray, elementType, newArraySize, s" $prettyName failed.") val assignment = - CodeGenerator.createArrayAssignment(newArray, elementType, arr, pos, i, false) + CodeGenerator.createArrayAssignment(newArray, elementType, arr, iPlus1, i, false) val newElemAssignment = - CodeGenerator.setArrayElement(newArray, elementType, pos, value, Some(rightGen.isNull)) + CodeGenerator.setArrayElement(newArray, elementType, zero, value, Some(rightGen.isNull)) s""" - |int $pos = 0; |int $newArraySize = $arr.numElements() + 1; |$allocation |$newElemAssignment - |$pos = $pos + 1; |for (int $i = 0; $i < $arr.numElements(); $i ++) { | $assignment - | $pos = $pos + 1; |} |${ev.value} = $newArray; |""".stripMargin From 34cb724a3538717b412eb549038ab841e3185437 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 12 Mar 2023 18:54:50 -0700 Subject: [PATCH 28/68] Update version --- python/pyspark/sql/functions.py | 2 +- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 12142b9e2ca0..dac7cddb880a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7626,7 +7626,7 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column: well as all elements from array. The new element is positioned at the beginning of the array. - .. versionadded:: 3.4.0 + .. versionadded:: 3.5.0 Parameters ---------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 62b1c5afaa08..66efec732fe3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1417,7 +1417,7 @@ case class ArrayContains(left: Expression, right: Expression) NULL """, group = "array_funcs", - since = "3.4.0") + since = "3.5.0") case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 069e7b79fcb7..d771367f318c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4048,7 +4048,7 @@ object functions { * positioned at the beginning of the array. 
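* The element is implicitly cast to the array's element type when a common type exists.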
* * @group collection_funcs - * @since 3.4.0 + * @since 3.5.0 */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) From baa6cc730af9389e4f695c6d3157642e1d238414 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Wed, 15 Mar 2023 22:34:27 -0700 Subject: [PATCH 29/68] Address review comments --- .../expressions/collectionOperations.scala | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 66efec732fe3..b621f3df100c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1403,7 +1403,7 @@ case class ArrayContains(left: Expression, right: Expression) @ExpressionDescription( usage = """ _FUNC_(array, element) - Add the element at the beginning of the array passed as first - argument. Type of element should be similar to type of the elements of the array. + argument. Type of element should be the same as the type of the elements of the array. Null element is also prepended to the array. But if the array passed is NULL output is NULL """, @@ -1453,7 +1453,7 @@ case class ArrayPrepend(left: Expression, right: Expression) val leftGen = left.genCode(ctx) val rightGen = right.genCode(ctx) val f = (arr: String, value: String) => { - val newArraySize = ctx.freshName("newArraySize") + val newArraySize = s"$arr.numElements() + 1" val newArray = ctx.freshName("newArray") val i = ctx.freshName("i") val iPlus1 = s"$i+1" @@ -1468,7 +1468,6 @@ case class ArrayPrepend(left: Expression, right: Expression) val newElemAssignment = CodeGenerator.setArrayElement(newArray, elementType, zero, value, Some(rightGen.isNull)) s""" - |int $newArraySize = $arr.numElements() + 1; |$allocation |$newElemAssignment |for (int $i = 0; $i < $arr.numElements(); $i ++) { @@ -1487,17 +1486,19 @@ case class ArrayPrepend(left: Expression, right: Expression) } ev.copy(code = code""" - boolean ${ev.isNull} = true; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $nullSafeEval - """) + |boolean ${ev.isNull} = true; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$nullSafeEval + """.stripMargin + ) } else { ev.copy(code = code""" - ${leftGen.code} - ${rightGen.code} - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = FalseLiteral) + |${leftGen.code} + |${rightGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """.stripMargin, isNull = FalseLiteral) } } From 8aa8ae525de7e83dfd8f0e855069f053068282ee Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:34:49 -0800 Subject: [PATCH 30/68] Adds a array_prepend expression to catalyst --- .../reference/pyspark.sql/functions.rst | 1 + python/pyspark/sql/functions.py | 27 ++++- .../catalyst/analysis/FunctionRegistry.scala | 12 ++ .../expressions/collectionOperations.scala | 113 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 41 +++++++ .../org/apache/spark/sql/functions.scala | 10 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 52 ++++++++ 7 files changed, 255 insertions(+), 1 
deletion(-) diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index 70fc04ef9cf2..cbc46e1fae18 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -159,6 +159,7 @@ Collection Functions array_sort array_insert array_remove + array_prepend array_distinct array_intersect array_union diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 051fd52a13c0..2ee7e44c670a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7630,6 +7630,32 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +def array_prepend(col: "ColumnOrName", element: Any) -> Column: + """ + Collection function: Returns an array containing value as well as all elements from array. + The new element is positioned at the beginning of the array. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column containing array + element : + element to be prepended to the array + + Returns + ------- + :class:`~pyspark.sql.Column` + an array excluding given value. + + Examples + -------- + >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + >>> df.select(array_prepend(df.data, 1)).collect() + [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] + """ + return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: @@ -7661,7 +7687,6 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) - @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ad82a8361993..7ff11b15c6eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -697,6 +697,7 @@ object FunctionRegistry { expression[Sequence]("sequence"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), + expression[ArrayPrepend]("array_prepend"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), expression[MapFilter]("map_filter"), @@ -970,6 +971,7 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { +<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -979,6 +981,16 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) +======= + try { + builder(expressions) + } catch { + case e: AnalysisException => + val argTypes = expressions.map(_.dataType.typeName).mkString(", ") + throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( + name, argTypes, info.getUsage, e.getMessage) + } +>>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } diff --git 
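
Registering the expression under the name `array_prepend` in FunctionRegistry is what makes it resolvable from SQL text. A quick usage sketch, assuming a local build with this patch applied:

import org.apache.spark.sql.SparkSession

object SqlUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    // Resolves through FunctionRegistry, then evaluates ArrayPrepend.
    spark.sql("SELECT array_prepend(array(1, 2, 3), 0)").show(false) // single row: [0, 1, 2, 3]
    spark.stop()
  }
}
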
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 289859d420bb..c003371c27a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1399,6 +1399,119 @@ case class ArrayContains(left: Expression, right: Expression) copy(left = newLeft, right = newRight) } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = + "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. The new element is positioned at the beginning of the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 4); + [4, 1, 2, 3] + """, + group = "array_funcs", + since = "3.4.0") +case class ArrayPrepend(left: Expression, right: Expression) + extends BinaryExpression + with ImplicitCastInputTypes + with NullIntolerant + with QueryErrorsBase { + + override def nullSafeEval(arr: Any, value: Any): Any = { + val numberOfElements = arr.asInstanceOf[ArrayData].numElements() + if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) + } + val newArray = new Array[Any](numberOfElements + 1) + newArray(0) = value + var pos = 1 + arr + .asInstanceOf[ArrayData] + .foreach( + right.dataType, + (i, v) => { + newArray(pos) = v + pos += 1 + }) + new GenericArrayData(newArray) + } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen( + ctx, + ev, + (arr, value) => { + val newArraySize = ctx.freshName("newArraySize") + val newArray = ctx.freshName("newArray") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val allocation = CodeGenerator.createArrayData( + newArray, + right.dataType, + newArraySize, + s" $prettyName failed.") + val assignment = + CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) + val newElemAssignment = + CodeGenerator.setArrayElement(newArray, right.dataType, pos, value) + s""" + |int $pos = 0; + |int $newArraySize = $arr.numElements() + 1; + |$allocation + |$newElemAssignment + |$pos = $pos + 1; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $assignment + | $pos = $pos + 1; + |} + |${ev.value} = $newArray; + |""".stripMargin + }) + } + + override def prettyName: String = "array_prepend" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayPrepend = + copy(left = newLeft, right = newRight) + override def dataType: DataType = left.dataType + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (_, NullType) | (NullType, _) => + DataTypeMismatch( + errorSubClass = "NULL_TYPE", + messageParameters = Map("functionName" -> toSQLId(prettyName))) + case (l, _) if !ArrayType.acceptsType(l) => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "1", + "requiredType" -> toSQLType(ArrayType), + "inputSql" -> toSQLExpr(left), + "inputType" -> toSQLType(left.dataType))) + case (ArrayType(e1, _), e2) if e1.sameType(e2) => + TypeUtils.checkForOrderingExpr(e2, prettyName) + case _ => + DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> 
toSQLId(prettyName), + "dataType" -> toSQLType(ArrayType), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType))) + } + } + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (_, NullType) => Seq.empty + case (ArrayType(e1, hasNull), e2) => + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull), dt) + case _ => Seq.empty + } + case _ => Seq.empty + } + } +} + /** * Checks if the two arrays contain at least one common element. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 60300ba62f2f..63bfc76179f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1855,6 +1855,47 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) } + test("SPARK-41233: ArrayPrepend") { + val a0 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType)) + val a1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val a2 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4)) + checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c")) + checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1)) + checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), null) + checkEvaluation(ArrayPrepend(a3, Literal("a")), null) + checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null) + + // complex data types + val b0 = Literal.create( + Seq[Array[Byte]]( + Array[Byte](5, 6), + Array[Byte](1, 2), + Array[Byte](1, 2), + Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType)) + val nullBinary = Literal.create(null, BinaryType) + checkEvaluation(ArrayPrepend(b0, nullBinary), null) + val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType) + checkEvaluation( + ArrayPrepend(b1, dataToPrepend1), + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](2, 1), null)) + + val c0 = Literal.create( + Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val dataToPrepend2 = Literal.create(Seq[Int](5, 6), ArrayType(IntegerType)) + checkEvaluation( + ArrayPrepend(c0, dataToPrepend2), + Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation( + ArrayPrepend(c0, Literal.create(Seq.empty[Int], ArrayType(IntegerType))), + Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4))) + } + test("Array remove") { val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType)) val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cb5c1ad5c495..d2f1df8780b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,6 +4042,16 @@ object functions { */ def array_compact(column: Column): Column = withExpr { 
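
The `inputTypes` above drive implicit casting: when the element type differs from the array's element type but a wider type exists, both sides are cast to it before evaluation. A sketch of the effect, matching the `1.23D` case exercised in DataFrameFunctionsSuite (assumes a patched build):

import org.apache.spark.sql.SparkSession

object CoercionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    // int elements and a double value widen to array<double>.
    spark.sql("SELECT array_prepend(array(1, 2), 1.23D)").show(false) // [1.23, 1.0, 2.0]
    spark.stop()
  }
}
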
ArrayCompact(column.expr) + + /** + * Returns an array containing value as well as all elements from array. The new element is + * positioned at the beginning of the array. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index bd03d2928204..fcff7fb6adf2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2651,6 +2651,58 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) } + test("SPARK-41233: array prepend") { + val df = Seq( + (Array[Int](2, 3, 4), Array("b", "c", "d"), Array("", ""), 2), + (Array.empty[Int], Array.empty[String], Array.empty[String], 2), + (null, null, null, 2)).toDF("a", "b", "c", "d") + checkAnswer( + df.select(array_prepend($"a", 1), array_prepend($"b", "a"), array_prepend($"c", "")), + Seq( + Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")), + Row(Seq(1), Seq("a"), Seq("")), + Row(null, null, null))) + checkAnswer( + df.select(array_prepend($"a", $"d")), + Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + checkAnswer( + df.selectExpr("array_prepend(a, d)"), + Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + checkAnswer( + OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"), + Seq(Row(Seq(1.23, 1.0, 2.0)))) + checkAnswer( + df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"), + Seq( + Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")), + Row(Seq(1), Seq("a"), Seq("")), + Row(null, null, null))) + checkError( + exception = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_prepend(_1, _2)") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "paramIndex" -> "1", + "sqlExpr" -> "\"array_prepend(_1, _2)\"", + "inputSql" -> "\"_1\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"ARRAY\""), + queryContext = Array(ExpectedContext("", "", 0, 20, "array_prepend(_1, _2)"))) + checkError( + exception = intercept[AnalysisException] { + OneRowRelation().selectExpr("array_prepend(array(1, 2), '1')") + }, + errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", + parameters = Map( + "sqlExpr" -> "\"array_prepend(array(1, 2), 1)\"", + "functionName" -> "`array_prepend`", + "dataType" -> "\"ARRAY\"", + "leftType" -> "\"ARRAY\"", + "rightType" -> "\"STRING\""), + queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')"))) + } + test("array remove") { val df = Seq( (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2), From 1ba91c7e8755720d79c65a50321dcd735c65942c Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:36:47 -0800 Subject: [PATCH 31/68] Fix null handling --- .../expressions/collectionOperations.scala | 122 +++++++++++------- .../CollectionExpressionsSuite.scala | 19 +-- .../org/apache/spark/sql/functions.scala | 15 +++ 3 files changed, 101 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c003371c27a1..6443f342e56a 
100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1413,9 +1413,19 @@ case class ArrayContains(left: Expression, right: Expression) case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes - with NullIntolerant with QueryErrorsBase { + override def nullable: Boolean = left.nullable + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + nullSafeEval(value1, value2) + } + } override def nullSafeEval(arr: Any, value: Any): Any = { val numberOfElements = arr.asInstanceOf[ArrayData].numElements() if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { @@ -1435,36 +1445,57 @@ case class ArrayPrepend(left: Expression, right: Expression) new GenericArrayData(newArray) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen( - ctx, - ev, - (arr, value) => { - val newArraySize = ctx.freshName("newArraySize") - val newArray = ctx.freshName("newArray") - val i = ctx.freshName("i") - val pos = ctx.freshName("pos") - val allocation = CodeGenerator.createArrayData( - newArray, - right.dataType, - newArraySize, - s" $prettyName failed.") - val assignment = - CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) - val newElemAssignment = - CodeGenerator.setArrayElement(newArray, right.dataType, pos, value) + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val f = (arr: String, value: String) => { + val newArraySize = ctx.freshName("newArraySize") + val newArray = ctx.freshName("newArray") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val allocation = CodeGenerator.createArrayData( + newArray, + right.dataType, + newArraySize, + s" $prettyName failed.") + val assignment = + CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) + val newElemAssignment = + CodeGenerator.setArrayElement(newArray, right.dataType, pos, value, Some(rightGen.isNull)) + s""" + |int $pos = 0; + |int $newArraySize = $arr.numElements() + 1; + |$allocation + |$newElemAssignment + |$pos = $pos + 1; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $assignment + | $pos = $pos + 1; + |} + |${ev.value} = $newArray; + |""".stripMargin + } + val resultCode = f(leftGen.value, rightGen.value) + if(nullable) { + val nullSafeEval = leftGen.code + rightGen.code + ctx.nullSafeExec(nullable, leftGen.isNull) { s""" - |int $pos = 0; - |int $newArraySize = $arr.numElements() + 1; - |$allocation - |$newElemAssignment - |$pos = $pos + 1; - |for (int $i = 0; $i < $arr.numElements(); $i ++) { - | $assignment - | $pos = $pos + 1; - |} - |${ev.value} = $newArray; + |${ev.isNull} = false; + |${resultCode} |""".stripMargin - }) + } + ev.copy(code = + code""" + boolean ${ev.isNull} = true; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $nullSafeEval + """) + } else { + ev.copy(code = + code""" + ${leftGen.code} + ${rightGen.code} + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $resultCode""", isNull = FalseLiteral) + } } override def prettyName: String = "array_prepend" @@ -1472,31 +1503,30 @@ case class ArrayPrepend(left: Expression, right: Expression) override protected def 
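
The hand-written `eval` above encodes the null contract the rest of the series keeps refining: a NULL array short-circuits to NULL, while a NULL element is still prepended. A sketch of both cases (assumes a patched build):

import org.apache.spark.sql.SparkSession

object NullContractSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    spark.sql("SELECT array_prepend(CAST(NULL AS ARRAY<INT>), 1)").show()    // NULL
    spark.sql("SELECT array_prepend(array(1, 2), CAST(NULL AS INT))").show() // [null,1,2]
    spark.stop()
  }
}
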
withNewChildrenInternal( newLeft: Expression, newRight: Expression): ArrayPrepend = copy(left = newLeft, right = newRight) + override def dataType: DataType = left.dataType + override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { - case (_, NullType) | (NullType, _) => - DataTypeMismatch( - errorSubClass = "NULL_TYPE", - messageParameters = Map("functionName" -> toSQLId(prettyName))) - case (l, _) if !ArrayType.acceptsType(l) => + case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess + case (ArrayType(e1, _), e2) => DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType), + "dataType" -> toSQLType(ArrayType) + )) + case _ => DataTypeMismatch( errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( - "paramIndex" -> "1", + "paramIndex" -> "0", "requiredType" -> toSQLType(ArrayType), "inputSql" -> toSQLExpr(left), - "inputType" -> toSQLType(left.dataType))) - case (ArrayType(e1, _), e2) if e1.sameType(e2) => - TypeUtils.checkForOrderingExpr(e2, prettyName) - case _ => - DataTypeMismatch( - errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", - messageParameters = Map( - "functionName" -> toSQLId(prettyName), - "dataType" -> toSQLType(ArrayType), - "leftType" -> toSQLType(left.dataType), - "rightType" -> toSQLType(right.dataType))) + "inputType" -> toSQLType(left.dataType) + ) + ) } } override def inputTypes: Seq[AbstractDataType] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 63bfc76179f7..1d00ec0cd8d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,10 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone - import scala.language.implicitConversions import scala.util.Random - import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -1864,21 +1862,24 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4)) checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c")) checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1)) - checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), null) + checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), Seq(null)) checkEvaluation(ArrayPrepend(a3, Literal("a")), null) checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null) // complex data types + val data = Seq[Array[Byte]]( + Array[Byte](5, 6), + Array[Byte](1, 2), + Array[Byte](1, 2), + Array[Byte](5, 6)) val b0 = Literal.create( - Seq[Array[Byte]]( - Array[Byte](5, 6), - Array[Byte](1, 2), - Array[Byte](1, 2), - Array[Byte](5, 6)), + data, ArrayType(BinaryType)) val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType)) val nullBinary = Literal.create(null, BinaryType) - checkEvaluation(ArrayPrepend(b0, nullBinary), null) + // Calling ArrayPrepend with a 
null element should result in NULL being prepended to the array + val dataWithNullPrepended = null +: data + checkEvaluation(ArrayPrepend(b0, nullBinary), dataWithNullPrepended) val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType) checkEvaluation( ArrayPrepend(b1, dataToPrepend1), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d2f1df8780b3..1f66a5daa2d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4050,6 +4050,21 @@ object functions { * @group collection_funcs * @since 3.4.0 */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) + + /** + * Remove all null elements from the given array. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_compact(column: Column): Column = withExpr { + ArrayCompact(column.expr) + /** + * Returns an array containing value as well as all elements from array.The + * new element is positioned at the beginning of the array. + */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From 90c0c28bdc08f944c1cbfb5a151ec929decacb21 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 32/68] Fix --- .../scala/org/apache/spark/sql/functions.scala | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 1f66a5daa2d2..069e7b79fcb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,7 +4042,7 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) - + } /** * Returns an array containing value as well as all elements from array. The new element is * positioned at the beginning of the array. @@ -4050,21 +4050,6 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def array_prepend(column: Column, element: Any): Column = withExpr { - ArrayPrepend(column.expr, lit(element).expr) - - /** - * Remove all null elements from the given array. - * - * @group collection_funcs - * @since 3.4.0 - */ - def array_compact(column: Column): Column = withExpr { - ArrayCompact(column.expr) - /** - * Returns an array containing value as well as all elements from array.The - * new element is positioned at the beginning of the array. 
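
Because `dataType` above is promoted with `asNullable` whenever the right side is nullable, prepending a possibly-null value flips `containsNull` on the result's element type. A schema-level sketch (assumes a patched build; the printed column name may differ slightly):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array_prepend, col, lit}

object SchemaSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._
    val df = Seq(Seq(1, 2, 3)).toDF("a")
    // The result's element type reports containsNull = true.
    df.select(array_prepend(col("a"), lit(null).cast("int"))).printSchema()
    spark.stop()
  }
}
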
- */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From 0a69172a9205a005f7e5ba7cb90dd30e2ea72b53 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 33/68] Fix --- .../sql/catalyst/analysis/FunctionRegistry.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7ff11b15c6eb..aca73741c639 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -971,7 +971,6 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { -<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -981,16 +980,6 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) -======= - try { - builder(expressions) - } catch { - case e: AnalysisException => - val argTypes = expressions.map(_.dataType.typeName).mkString(", ") - throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( - name, argTypes, info.getUsage, e.getMessage) - } ->>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } From db598804df9b6572940c7fd2637ef42c149d49d3 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:45:20 -0800 Subject: [PATCH 34/68] Lint --- .../expressions/CollectionExpressionsSuite.scala | 1 + .../spark/sql/DataFrameFunctionsSuite.scala | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 1d00ec0cd8d2..fced26284885 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone + import scala.language.implicitConversions import scala.util.Random import org.apache.spark.{SparkFunSuite, SparkRuntimeException} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index fcff7fb6adf2..c238f56123ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2664,13 +2664,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(null, null, null))) checkAnswer( df.select(array_prepend($"a", $"d")), - Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + Seq( + Row(Seq(2, 2, 3, 4)), + Row(Seq(2)), + Row(null))) checkAnswer( df.selectExpr("array_prepend(a, d)"), - Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + Seq( + Row(Seq(2, 2, 3, 4)), + 
Row(Seq(2)), + Row(null))) checkAnswer( OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"), - Seq(Row(Seq(1.23, 1.0, 2.0)))) + Seq( + Row(Seq(1.23, 1.0, 2.0)) + ) + ) checkAnswer( df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"), Seq( From f0d9329534d4af65eb8e312ae8a28be4fb435791 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:47:55 -0800 Subject: [PATCH 35/68] Lint --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index fced26284885..3abc70a3d551 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -23,6 +23,7 @@ import java.util.TimeZone import scala.language.implicitConversions import scala.util.Random + import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow From ae5b65e5b7b851836617149d573e6e5bf2c3cc92 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Wed, 8 Feb 2023 20:46:44 -0800 Subject: [PATCH 36/68] Add examples of usage and fix test --- python/pyspark/sql/functions.py | 6 +- .../sql-functions/sql-expression-schema.md | 3 +- .../test/resources/sql-tests/inputs/array.sql | 11 +++ .../sql-tests/results/ansi/array.sql.out | 72 +++++++++++++++++++ 4 files changed, 89 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 2ee7e44c670a..259a9b7dd601 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7630,10 +7630,12 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +@try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Returns an array containing value as well as all elements from array. - The new element is positioned at the beginning of the array. + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned + at the beginning of the array. .. 
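
The same queries are checked under both the ANSI and default dialects (ansi/array.sql.out above, array.sql.out later in the series), and array_prepend produces identical results in both. A sketch toggling the flag (assumes a patched build):

import org.apache.spark.sql.SparkSession

object AnsiSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.sql("SELECT array_prepend(array('a', 'b', 'c'), 'd')").show(false) // [d, a, b, c]
    spark.stop()
  }
}
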
versionadded:: 3.4.0 diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 0894d03f9d41..529f4e044bbe 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,6 +13,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | +| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | @@ -421,4 +422,4 @@ | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('b1b2b3c1c2','a/b/text()') | structb1b2b3c1c2, a/b/text()):array> | | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('12', 'sum(a/b)') | struct12, sum(a/b)):bigint> | | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('12', 'sum(a/b)') | struct12, sum(a/b)):smallint> | -| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | \ No newline at end of file +| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql index 3d107cb6dfc0..d3c36b79d1f3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/array.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -160,3 +160,14 @@ select array_append(CAST(null AS ARRAY), CAST(null as String)); select array_append(array(), 1); select array_append(CAST(array() AS ARRAY), CAST(NULL AS String)); select array_append(array(CAST(NULL AS String)), CAST(NULL AS String)); + +-- function array_prepend +select array_prepend(array(1, 2, 3), 4); +select array_prepend(array('a', 'b', 'c'), 'd'); +select array_prepend(array(1, 2, 3, NULL), NULL); +select array_prepend(array('a', 'b', 'c', NULL), NULL); +select array_prepend(CAST(null AS ARRAY), 'a'); +select array_prepend(CAST(null AS ARRAY), CAST(null as String)); +select array_prepend(array(), 1); +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)); +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out index 0d8ef39ed60c..d228c605705d 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out @@ -784,3 +784,75 @@ select array_append(array(CAST(NULL AS String)), 
CAST(NULL AS String)) struct> -- !query output [null,null] + + +-- !query +select array_prepend(array(1, 2, 3), 4) +-- !query schema +struct> +-- !query output +[4,1,2,3] + + +-- !query +select array_prepend(array('a', 'b', 'c'), 'd') +-- !query schema +struct> +-- !query output +["d","a","b","c"] + + +-- !query +select array_prepend(array(1, 2, 3, NULL), NULL) +-- !query schema +struct> +-- !query output +[null,1,2,3,null] + + +-- !query +select array_prepend(array('a', 'b', 'c', NULL), NULL) +-- !query schema +struct> +-- !query output +[null,"a","b","c",null] + + +-- !query +select array_prepend(CAST(null AS ARRAY), 'a') +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(CAST(null AS ARRAY), CAST(null as String)) +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(array(), 1) +-- !query schema +struct> +-- !query output +[1] + + +-- !query +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null] + + +-- !query +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null,null] From af3ee0abbab06cb03cc0e23d41cb04722b738a1b Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Thu, 9 Feb 2023 20:44:09 -0800 Subject: [PATCH 37/68] Fix tests --- python/pyspark/sql/functions.py | 9 +-- .../expressions/collectionOperations.scala | 3 +- .../sql-functions/sql-expression-schema.md | 2 +- .../resources/sql-tests/results/array.sql.out | 72 +++++++++++++++++++ 4 files changed, 80 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 259a9b7dd601..80b806a02bdc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7633,10 +7633,10 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Returns an array containing element as - well as all elements from array. The new element is positioned + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned at the beginning of the array. - + .. versionadded:: 3.4.0 Parameters @@ -7648,6 +7648,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Returns ------- + :class:`~pyspark.sql.Column` an array excluding given value. 
@@ -7656,7 +7657,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6443f342e56a..737608790ef2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1413,6 +1413,7 @@ case class ArrayContains(left: Expression, right: Expression) case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes + with ComplexTypeMergingExpression with QueryErrorsBase { override def nullable: Boolean = left.nullable @@ -1533,7 +1534,7 @@ case class ArrayPrepend(left: Expression, right: Expression) (left.dataType, right.dataType) match { case (_, NullType) => Seq.empty case (ArrayType(e1, hasNull), e2) => - TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match { + TypeCoercion.findTightestCommonType(e1, e2) match { case Some(dt) => Seq(ArrayType(dt, hasNull), dt) case _ => Seq.empty } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 529f4e044bbe..6b5b67f98491 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,7 +13,6 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | -| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | @@ -27,6 +26,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayMax | array_max | SELECT array_max(array(1, 20, null, 3)) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayMin | array_min | SELECT array_min(array(1, 20, null, 3)) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayPosition | array_position | SELECT array_position(array(3, 2, 1), 1) | struct | +| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayRemove | array_remove | SELECT array_remove(array(1, 2, 3, null, 3), 3) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayRepeat | 
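
Switching `inputTypes` to `TypeCoercion.findTightestCommonType` above is what rules out string promotion: int and string share no tightest common type, so no cast is inserted and `array_prepend(array(1, 2), '1')` fails analysis instead of silently promoting both sides to string. A small sketch against the catalyst helper (assumes spark-catalyst on the classpath):

import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.types.{IntegerType, LongType, StringType}

object TightestTypeSketch {
  def main(args: Array[String]): Unit = {
    println(TypeCoercion.findTightestCommonType(IntegerType, LongType))   // Some(LongType)
    println(TypeCoercion.findTightestCommonType(IntegerType, StringType)) // None
  }
}
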
array_repeat | SELECT array_repeat('123', 2) | struct> | | org.apache.spark.sql.catalyst.expressions.ArraySize | array_size | SELECT array_size(array('b', 'd', 'c', 'a')) | struct | diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out index 609122a23d31..029bd767f54c 100644 --- a/sql/core/src/test/resources/sql-tests/results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -665,3 +665,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String)) struct> -- !query output [null,null] + + +-- !query +select array_prepend(array(1, 2, 3), 4) +-- !query schema +struct> +-- !query output +[4,1,2,3] + + +-- !query +select array_prepend(array('a', 'b', 'c'), 'd') +-- !query schema +struct> +-- !query output +["d","a","b","c"] + + +-- !query +select array_prepend(array(1, 2, 3, NULL), NULL) +-- !query schema +struct> +-- !query output +[null,1,2,3,null] + + +-- !query +select array_prepend(array('a', 'b', 'c', NULL), NULL) +-- !query schema +struct> +-- !query output +[null,"a","b","c",null] + + +-- !query +select array_prepend(CAST(null AS ARRAY), 'a') +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(CAST(null AS ARRAY), CAST(null as String)) +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(array(), 1) +-- !query schema +struct> +-- !query output +[1] + + +-- !query +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null] + + +-- !query +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null,null] From 3265717d555f55f234f398c43f541f17e9442043 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 10 Feb 2023 22:03:34 -0800 Subject: [PATCH 38/68] Fix types --- python/pyspark/sql/functions.py | 2 +- .../spark/sql/catalyst/expressions/collectionOperations.scala | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 80b806a02bdc..cfd0f8674378 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7636,7 +7636,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Collection function: Returns an array containing element as well as all elements from array. The new element is positioned at the beginning of the array. - + .. 
versionadded:: 3.4.0 Parameters diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 737608790ef2..1052f8053050 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1532,7 +1532,6 @@ case class ArrayPrepend(left: Expression, right: Expression) } override def inputTypes: Seq[AbstractDataType] = { (left.dataType, right.dataType) match { - case (_, NullType) => Seq.empty case (ArrayType(e1, hasNull), e2) => TypeCoercion.findTightestCommonType(e1, e2) match { case Some(dt) => Seq(ArrayType(dt, hasNull), dt) From 7df63ea1df13f7f33f23fbc182cd5cced95d49e4 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 16:46:34 -0800 Subject: [PATCH 39/68] Fix tests --- python/pyspark/sql/functions.py | 1 - .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++-- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index cfd0f8674378..554037eb0dff 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7648,7 +7648,6 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Returns ------- - :class:`~pyspark.sql.Column` an array excluding given value. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 1052f8053050..1c6e17a1bdc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1405,8 +1405,8 @@ case class ArrayContains(left: Expression, right: Expression) "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. 
The new element is positioned at the beginning of the array.", examples = """ Examples: - > SELECT _FUNC_(array(1, 2, 3), 4); - [4, 1, 2, 3] + > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd'); + ["d","b","d","c","a"] """, group = "array_funcs", since = "3.4.0") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c238f56123ea..a8929d5c8bdc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2692,7 +2692,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { }, errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( - "paramIndex" -> "1", + "paramIndex" -> "0", "sqlExpr" -> "\"array_prepend(_1, _2)\"", "inputSql" -> "\"_1\"", "inputType" -> "\"STRING\"", From b4fbbd509102ca847e2c3cc0310481d898fea4d7 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 19:18:02 -0800 Subject: [PATCH 40/68] Fix python linter --- python/pyspark/sql/functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 554037eb0dff..6608f8d317e6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7630,6 +7630,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) + @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ @@ -7656,9 +7657,10 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) + @try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: """ @@ -7689,6 +7691,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) + @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ From 413af39a50279f8ca3492077cba9dc8666796d0a Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 20:30:18 -0800 Subject: [PATCH 41/68] Add test for null cases --- .../expressions/collectionOperations.scala | 28 ++++++++----------- .../spark/sql/DataFrameFunctionsSuite.scala | 7 +++++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 1c6e17a1bdc0..b52eaf3cff61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1418,6 +1418,9 @@ case class ArrayPrepend(left: Expression, right: Expression) override def nullable: Boolean = left.nullable + @transient protected lazy val elementType: DataType = + inputTypes.head.asInstanceOf[ArrayType].elementType + override def eval(input: InternalRow): Any = { val value1 = left.eval(input) if (value1 == null) { @@ -1427,23 +1430,16 @@ case class ArrayPrepend(left: 
Expression, right: Expression) nullSafeEval(value1, value2) } } - override def nullSafeEval(arr: Any, value: Any): Any = { - val numberOfElements = arr.asInstanceOf[ArrayData].numElements() - if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + override def nullSafeEval(arr: Any, elementData: Any): Any = { + val arrayData = arr.asInstanceOf[ArrayData] + val numberOfElements = arrayData.numElements() + 1 + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) } - val newArray = new Array[Any](numberOfElements + 1) - newArray(0) = value - var pos = 1 - arr - .asInstanceOf[ArrayData] - .foreach( - right.dataType, - (i, v) => { - newArray(pos) = v - pos += 1 - }) - new GenericArrayData(newArray) + val finalData = new Array[Any](numberOfElements) + finalData.update(0, elementData) + arrayData.foreach(elementType, (i: Int, v: Any) => finalData.update(i + 1, v)) + new GenericArrayData(finalData) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val leftGen = left.genCode(ctx) @@ -1505,7 +1501,7 @@ case class ArrayPrepend(left: Expression, right: Expression) newLeft: Expression, newRight: Expression): ArrayPrepend = copy(left = newLeft, right = newRight) - override def dataType: DataType = left.dataType + override def dataType: DataType = if (right.nullable) left.dataType.asNullable else left.dataType override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index a8929d5c8bdc..355f2dfffb57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2710,6 +2710,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "leftType" -> "\"ARRAY\"", "rightType" -> "\"STRING\""), queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')"))) + val df2 = Seq((Array[String]("a", "b", "c"), "d"), + (null, "d"), + (Array[String]("x", "y", "z"), null), + (null, null) + ).toDF("a", "b") + checkAnswer(df2.selectExpr("array_prepend(a, b)"), + Seq(Row(Seq("d", "a", "b", "c")), Row(null), Row(Seq(null, "x", "y", "z")), Row(null))) } test("array remove") { From 9992e33836e6e9df26d24f86a88366c5a8a0e037 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 20:35:43 -0800 Subject: [PATCH 42/68] Fix type of array --- .../sql/catalyst/expressions/collectionOperations.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b52eaf3cff61..f35c1c15243e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1451,13 +1451,13 @@ case class ArrayPrepend(left: Expression, right: Expression) val pos = ctx.freshName("pos") val allocation = CodeGenerator.createArrayData( newArray, - right.dataType, + elementType, newArraySize, s" $prettyName failed.") val assignment = - CodeGenerator.createArrayAssignment(newArray, 
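
The interpreted path above also guards the new size against `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH` before allocating, mirroring other collection expressions. A standalone sketch of that guard (the constant's value, Integer.MAX_VALUE - 15, is an assumption worth re-checking against the Spark source):

object SizeGuardSketch {
  // Assumed value of ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.
  val MaxRoundedArrayLength: Int = Int.MaxValue - 15

  def checkedNewSize(currentNumElements: Int): Int = {
    val n = currentNumElements + 1
    if (n < 0 || n > MaxRoundedArrayLength) {
      throw new IllegalStateException(s"Cannot create array with $n elements")
    }
    n
  }

  def main(args: Array[String]): Unit = {
    println(checkedNewSize(3)) // 4
  }
}
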
right.dataType, arr, pos, i, false) + CodeGenerator.createArrayAssignment(newArray, elementType, arr, pos, i, false) val newElemAssignment = - CodeGenerator.setArrayElement(newArray, right.dataType, pos, value, Some(rightGen.isNull)) + CodeGenerator.setArrayElement(newArray, elementType, pos, value, Some(rightGen.isNull)) s""" |int $pos = 0; |int $newArraySize = $arr.numElements() + 1; From 684a7d96f658f3905ef5194b0e7aaa1639de0538 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:34:49 -0800 Subject: [PATCH 43/68] Adds a array_prepend expression to catalyst --- python/pyspark/sql/functions.py | 27 ++++++++++++++++++- .../catalyst/analysis/FunctionRegistry.scala | 11 ++++++++ .../org/apache/spark/sql/functions.scala | 10 +++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6608f8d317e6..22215b5958a9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7630,6 +7630,32 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +def array_prepend(col: "ColumnOrName", element: Any) -> Column: + """ + Collection function: Returns an array containing value as well as all elements from array. + The new element is positioned at the beginning of the array. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column containing array + element : + element to be prepended to the array + + Returns + ------- + :class:`~pyspark.sql.Column` + an array excluding given value. + + Examples + -------- + >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + >>> df.select(array_prepend(df.data, 1)).collect() + [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] + """ + return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: @@ -7691,7 +7717,6 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) - @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index aca73741c639..7ff11b15c6eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -971,6 +971,7 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { +<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -980,6 +981,16 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) +======= + try { + builder(expressions) + } catch { + case e: AnalysisException => + val argTypes = expressions.map(_.dataType.typeName).mkString(", ") + throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( + name, argTypes, info.getUsage, e.getMessage) + } +>>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 069e7b79fcb7..b0538af2fd6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,6 +4042,16 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) + + /** + * Returns an array containing value as well as all elements from array. The new element is + * positioned at the beginning of the array. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) } /** * Returns an array containing value as well as all elements from array. The new element is From 93f181917f9b300b8eb38f11328335a2d7b70a56 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:36:47 -0800 Subject: [PATCH 44/68] Fix null handling --- .../expressions/CollectionExpressionsSuite.scala | 2 -- .../scala/org/apache/spark/sql/functions.scala | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 3abc70a3d551..1d00ec0cd8d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,10 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone - import scala.language.implicitConversions import scala.util.Random - import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b0538af2fd6c..f99fbbe8eefa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4050,6 +4050,21 @@ object functions { * @group collection_funcs * @since 3.4.0 */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) + + /** + * Remove all null elements from the given array. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_compact(column: Column): Column = withExpr { + ArrayCompact(column.expr) + /** + * Returns an array containing value as well as all elements from array.The + * new element is positioned at the beginning of the array. 
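
Once the duplicated definitions are cleaned up in the following patches, the public Scala API boils down to the single `array_prepend(column, element)` shown above. A usage sketch against the DataFrame API (assumes a patched build):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.array_prepend

object DataFrameUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._
    val df = Seq(Seq(2, 3, 4), Seq.empty[Int]).toDF("data")
    df.select(array_prepend($"data", 1)).show(false) // [1, 2, 3, 4] and [1]
    spark.stop()
  }
}
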
+ */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From a8db7b3823778430d94a614861f7a4dbbe04ae7f Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 45/68] Fix --- .../scala/org/apache/spark/sql/functions.scala | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f99fbbe8eefa..9c102767c422 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,7 +4042,7 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) - + } /** * Returns an array containing value as well as all elements from array. The new element is * positioned at the beginning of the array. @@ -4050,21 +4050,6 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def array_prepend(column: Column, element: Any): Column = withExpr { - ArrayPrepend(column.expr, lit(element).expr) - - /** - * Remove all null elements from the given array. - * - * @group collection_funcs - * @since 3.4.0 - */ - def array_compact(column: Column): Column = withExpr { - ArrayCompact(column.expr) - /** - * Returns an array containing value as well as all elements from array.The - * new element is positioned at the beginning of the array. - */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From 82186b9a4fd60f2886d630d951aeced9b6de36d3 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 46/68] Fix --- .../sql/catalyst/analysis/FunctionRegistry.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7ff11b15c6eb..aca73741c639 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -971,7 +971,6 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { -<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -981,16 +980,6 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) -======= - try { - builder(expressions) - } catch { - case e: AnalysisException => - val argTypes = expressions.map(_.dataType.typeName).mkString(", ") - throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( - name, argTypes, info.getUsage, e.getMessage) - } ->>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } From 30988b7ad179cdb6f811393f318add51c72d8f49 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:45:20 -0800 Subject: [PATCH 47/68] Lint --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 1d00ec0cd8d2..fced26284885 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone + import scala.language.implicitConversions import scala.util.Random import org.apache.spark.{SparkFunSuite, SparkRuntimeException} From 09a61cad634bb1874a4b3ae761e6922e5cd44be7 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:47:55 -0800 Subject: [PATCH 48/68] Lint --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index fced26284885..3abc70a3d551 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -23,6 +23,7 @@ import java.util.TimeZone import scala.language.implicitConversions import scala.util.Random + import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow From d188279dd096b3e2927da2c3de4c02445fd07760 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Wed, 8 Feb 2023 20:46:44 -0800 Subject: [PATCH 49/68] Add examples of usage and fix test --- python/pyspark/sql/functions.py | 6 ++++-- .../test/resources/sql-functions/sql-expression-schema.md | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 22215b5958a9..b8b09182e907 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7630,10 +7630,12 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +@try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Returns an array containing value as well as all elements from array. - The new element is positioned at the beginning of the array. + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned + at the beginning of the array. .. 
versionadded:: 3.4.0 diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 6b5b67f98491..0cbb896fe03d 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,6 +13,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | +| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | From 15b713d81b5578ae6d77b816ace8900a3decc9cb Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Thu, 9 Feb 2023 20:44:09 -0800 Subject: [PATCH 50/68] Fix tests --- python/pyspark/sql/functions.py | 9 +++++---- .../resources/sql-functions/sql-expression-schema.md | 1 - 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b8b09182e907..f57b9ee58c79 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7633,10 +7633,10 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Returns an array containing element as - well as all elements from array. The new element is positioned + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned at the beginning of the array. - + .. versionadded:: 3.4.0 Parameters @@ -7648,6 +7648,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Returns ------- + :class:`~pyspark.sql.Column` an array with the given value prepended.
@@ -7656,7 +7657,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 0cbb896fe03d..6b5b67f98491 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,7 +13,6 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | -| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | From 380b15658decb3ba8e2b6a4aaaa5a417359ea965 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 10 Feb 2023 22:03:34 -0800 Subject: [PATCH 51/68] Fix types --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f57b9ee58c79..f08beff97434 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7636,7 +7636,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Collection function: Returns an array containing element as well as all elements from array. The new element is positioned at the beginning of the array. - + .. versionadded:: 3.4.0 Parameters From 4ecfac854048e07002569bcb20d6e35567f85b97 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 16:46:34 -0800 Subject: [PATCH 52/68] Fix tests --- python/pyspark/sql/functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f08beff97434..2128e1ba1d1e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7648,7 +7648,6 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Returns ------- - :class:`~pyspark.sql.Column` an array with the given value prepended.
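At this point in the series the Python API is in its intended shape. For reference, a minimal PySpark sketch of the documented behavior (illustrative and not part of any patch; it assumes a Spark build that already includes this series, and the local session setup is an assumption):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import array_prepend

    # Local session for demonstration only; any active SparkSession works.
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([([2, 3, 4],), ([],)], ["data"])
    # The new element lands at position 0; an empty array becomes a one-element array.
    df.select(array_prepend(df.data, 1).alias("r")).collect()
    # [Row(r=[1, 2, 3, 4]), Row(r=[1])]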
From 160db20dc6b3b128b4bdacf6718e63b271a8fdb4 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 19:18:02 -0800 Subject: [PATCH 53/68] Fix python linter --- python/pyspark/sql/functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 2128e1ba1d1e..66f435f3614a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7630,6 +7630,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) + @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ @@ -7656,9 +7657,10 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) + @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ @@ -7719,6 +7721,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) + @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ From 3aa673f5213106c66b060039baffe6825fd3832c Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 28 Feb 2023 21:06:54 -0800 Subject: [PATCH 54/68] Address comments --- python/pyspark/sql/functions.py | 8 +++---- .../expressions/collectionOperations.scala | 22 ++++++++++++------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 66f435f3614a..aec0455fc20b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7632,7 +7632,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: @try_remote_functions -def array_prepend(col: "ColumnOrName", element: Any) -> Column: +def array_prepend(col: "ColumnOrName", value: Any) -> Column: """ Collection function: Returns an array containing element as well as all elements from array. The new element is positioned @@ -7644,8 +7644,8 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: ---------- col : :class:`~pyspark.sql.Column` or str name of column containing array - element : - element to be prepended to the array + value : + a literal value, or a :class:`~pyspark.sql.Column` expression. 
Returns ------- @@ -7658,7 +7658,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] """ - return _invoke_function("array_prepend", _to_java_column(col), element) + return _invoke_function_over_columns("array_prepend", col, lit(value)) @try_remote_functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index f35c1c15243e..366f035a5cd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1401,12 +1401,20 @@ case class ArrayContains(left: Expression, right: Expression) // scalastyle:off line.size.limit @ExpressionDescription( - usage = - "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. The new element is positioned at the beginning of the array.", + usage = """ + _FUNC_(array, element) - Add the element at the beginning of the array passed as first + argument. Type of element should be similar to type of the elements of the array. + Null element is also prepended to the array. But if the array passed is NULL + output is NULL + """, examples = """ Examples: > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd'); ["d","b","d","c","a"] + > SELECT _FUNC_(array(1, 2, 3, null), null); + [null,1,2,3,null] + > SELECT _FUNC_(CAST(null as Array), 2); + NULL """, group = "array_funcs", since = "3.4.0") @@ -1448,25 +1456,23 @@ case class ArrayPrepend(left: Expression, right: Expression) val newArraySize = ctx.freshName("newArraySize") val newArray = ctx.freshName("newArray") val i = ctx.freshName("i") - val pos = ctx.freshName("pos") + val iPlus1 = s"$i+1" + val zero = "0" val allocation = CodeGenerator.createArrayData( newArray, elementType, newArraySize, s" $prettyName failed.") val assignment = - CodeGenerator.createArrayAssignment(newArray, elementType, arr, pos, i, false) + CodeGenerator.createArrayAssignment(newArray, elementType, arr, iPlus1, i, false) val newElemAssignment = - CodeGenerator.setArrayElement(newArray, elementType, pos, value, Some(rightGen.isNull)) + CodeGenerator.setArrayElement(newArray, elementType, zero, value, Some(rightGen.isNull)) s""" - |int $pos = 0; |int $newArraySize = $arr.numElements() + 1; |$allocation |$newElemAssignment - |$pos = $pos + 1; |for (int $i = 0; $i < $arr.numElements(); $i ++) { | $assignment - | $pos = $pos + 1; |} |${ev.value} = $newArray; |""".stripMargin From 8b480a5ae50640a2a6120aebdadcaa306c750753 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 12 Mar 2023 18:54:50 -0700 Subject: [PATCH 55/68] Update version --- python/pyspark/sql/functions.py | 2 +- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index aec0455fc20b..624ff604cebe 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7638,7 +7638,7 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column: well as all elements from array. The new element is positioned at the beginning of the array. - .. versionadded:: 3.4.0 + .. 
versionadded:: 3.5.0 Parameters ---------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 366f035a5cd7..eaadef7c43b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1417,7 +1417,7 @@ case class ArrayContains(left: Expression, right: Expression) NULL """, group = "array_funcs", - since = "3.4.0") + since = "3.5.0") case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9c102767c422..9674eda7c2af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4048,7 +4048,7 @@ object functions { * positioned at the beginning of the array. * * @group collection_funcs - * @since 3.4.0 + * @since 3.5.0 */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) From 19fe92435de7cb4380baed9409bfd8b433d07044 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Wed, 15 Mar 2023 22:34:27 -0700 Subject: [PATCH 56/68] Address review comments --- .../expressions/collectionOperations.scala | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index eaadef7c43b2..2ccb3a6d0cd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1403,7 +1403,7 @@ case class ArrayContains(left: Expression, right: Expression) @ExpressionDescription( usage = """ _FUNC_(array, element) - Add the element at the beginning of the array passed as first - argument. Type of element should be similar to type of the elements of the array. + argument. Type of element should be the same as the type of the elements of the array. Null element is also prepended to the array. 
But if the array passed is NULL output is NULL """, @@ -1453,7 +1453,7 @@ case class ArrayPrepend(left: Expression, right: Expression) val leftGen = left.genCode(ctx) val rightGen = right.genCode(ctx) val f = (arr: String, value: String) => { - val newArraySize = ctx.freshName("newArraySize") + val newArraySize = s"$arr.numElements() + 1" val newArray = ctx.freshName("newArray") val i = ctx.freshName("i") val iPlus1 = s"$i+1" @@ -1468,7 +1468,6 @@ case class ArrayPrepend(left: Expression, right: Expression) val newElemAssignment = CodeGenerator.setArrayElement(newArray, elementType, zero, value, Some(rightGen.isNull)) s""" - |int $newArraySize = $arr.numElements() + 1; |$allocation |$newElemAssignment |for (int $i = 0; $i < $arr.numElements(); $i ++) { @@ -1487,17 +1486,19 @@ case class ArrayPrepend(left: Expression, right: Expression) } ev.copy(code = code""" - boolean ${ev.isNull} = true; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $nullSafeEval - """) + |boolean ${ev.isNull} = true; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$nullSafeEval + """.stripMargin + ) } else { ev.copy(code = code""" - ${leftGen.code} - ${rightGen.code} - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = FalseLiteral) + |${leftGen.code} + |${rightGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """.stripMargin, isNull = FalseLiteral) } } From b1cf31a6c799c2891e99eb76b1318ce2e11dc278 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:34:49 -0800 Subject: [PATCH 57/68] Adds an array_prepend expression to catalyst --- python/pyspark/sql/functions.py | 27 ++++++++++++++++++- .../catalyst/analysis/FunctionRegistry.scala | 11 ++++++++ .../org/apache/spark/sql/functions.scala | 10 +++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 624ff604cebe..d8ee5074c44f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7630,6 +7630,32 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +def array_prepend(col: "ColumnOrName", element: Any) -> Column: + """ + Collection function: Returns an array containing value as well as all elements from array. + The new element is positioned at the beginning of the array. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column containing array + element : + element to be prepended to the array + + Returns + ------- + :class:`~pyspark.sql.Column` + an array with the given value prepended.
+ + Examples + -------- + >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + >>> df.select(array_prepend(df.data, 1)).collect() + [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] + """ + return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions def array_prepend(col: "ColumnOrName", value: Any) -> Column: @@ -7721,7 +7747,6 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) - @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index aca73741c639..7ff11b15c6eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -971,6 +971,7 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { +<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -980,6 +981,16 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) +======= + try { + builder(expressions) + } catch { + case e: AnalysisException => + val argTypes = expressions.map(_.dataType.typeName).mkString(", ") + throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( + name, argTypes, info.getUsage, e.getMessage) + } +>>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9674eda7c2af..89aeccfe6ed4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,6 +4042,16 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) + + /** + * Returns an array containing value as well as all elements from array. The new element is + * positioned at the beginning of the array. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) } /** * Returns an array containing value as well as all elements from array. 
The new element is From c0c6a512871651be2e4f834b41ad64c62b3abe09 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:36:47 -0800 Subject: [PATCH 58/68] Fix null handling --- .../expressions/CollectionExpressionsSuite.scala | 2 -- .../scala/org/apache/spark/sql/functions.scala | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 3abc70a3d551..1d00ec0cd8d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,10 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone - import scala.language.implicitConversions import scala.util.Random - import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 89aeccfe6ed4..3aa778382d1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4050,6 +4050,21 @@ object functions { * @group collection_funcs * @since 3.4.0 */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) + + /** + * Remove all null elements from the given array. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_compact(column: Column): Column = withExpr { + ArrayCompact(column.expr) + /** + * Returns an array containing value as well as all elements from array.The + * new element is positioned at the beginning of the array. + */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From 422f393575fb82635169e67da0e5c2a5858815a0 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 59/68] Fix --- .../scala/org/apache/spark/sql/functions.scala | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3aa778382d1c..95a976d6ba32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,7 +4042,7 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) - + } /** * Returns an array containing value as well as all elements from array. The new element is * positioned at the beginning of the array. @@ -4050,21 +4050,6 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def array_prepend(column: Column, element: Any): Column = withExpr { - ArrayPrepend(column.expr, lit(element).expr) - - /** - * Remove all null elements from the given array. 
- * - * @group collection_funcs - * @since 3.4.0 - */ - def array_compact(column: Column): Column = withExpr { - ArrayCompact(column.expr) - /** - * Returns an array containing value as well as all elements from array.The - * new element is positioned at the beginning of the array. - */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From 2e193e4cdac011233506e5b0271c1eceda5ac353 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:45:20 -0800 Subject: [PATCH 60/68] Lint --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 1d00ec0cd8d2..fced26284885 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone + import scala.language.implicitConversions import scala.util.Random import org.apache.spark.{SparkFunSuite, SparkRuntimeException} From 95673b826bd58f126fde095e15c9338f663deaa8 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Wed, 8 Feb 2023 20:46:44 -0800 Subject: [PATCH 61/68] Add examples of usage and fix test --- python/pyspark/sql/functions.py | 6 ++++-- .../test/resources/sql-functions/sql-expression-schema.md | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d8ee5074c44f..f290f22e7b4b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7630,10 +7630,12 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +@try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Returns an array containing value as well as all elements from array. - The new element is positioned at the beginning of the array. + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned + at the beginning of the array. .. 
versionadded:: 3.4.0 diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 6b5b67f98491..0cbb896fe03d 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,6 +13,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | +| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | From 46e6dd78177b98e20ba421c98ad35b999c285343 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Thu, 9 Feb 2023 20:44:09 -0800 Subject: [PATCH 62/68] Fix tests --- python/pyspark/sql/functions.py | 9 +++++---- .../resources/sql-functions/sql-expression-schema.md | 1 - 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f290f22e7b4b..3b2e0e7ee760 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7633,10 +7633,10 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Returns an array containing element as - well as all elements from array. The new element is positioned + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned at the beginning of the array. - + .. versionadded:: 3.4.0 Parameters @@ -7648,6 +7648,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Returns ------- + :class:`~pyspark.sql.Column` an array with the given value prepended.
@@ -7656,7 +7657,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 0cbb896fe03d..6b5b67f98491 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,7 +13,6 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | -| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | From 67a64daf795f87f3dc5ccd71e4c35c3365adb20b Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 10 Feb 2023 22:03:34 -0800 Subject: [PATCH 63/68] Fix types --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3b2e0e7ee760..4e3354d8010d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7636,7 +7636,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Collection function: Returns an array containing element as well as all elements from array. The new element is positioned at the beginning of the array. - + .. versionadded:: 3.4.0 Parameters From 19505ff09da6bb7636c6d1d82f6c8a33727b4871 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 16:46:34 -0800 Subject: [PATCH 64/68] Fix tests --- python/pyspark/sql/functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4e3354d8010d..bcb88930cdcd 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7648,7 +7648,6 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Returns ------- - :class:`~pyspark.sql.Column` an array with the given value prepended.
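The null-handling rules recorded in the ExpressionDescription earlier in the series can be reproduced from PySpark with a short sketch (illustrative, not part of any patch; it assumes a build with this series applied and an active SparkSession named `spark`, as in the earlier sketch):

    # A null element is still prepended to a non-null array.
    spark.sql("SELECT array_prepend(array(1, 2, 3, null), null)").first()[0]
    # [None, 1, 2, 3, None]

    # A NULL input array yields NULL.
    spark.sql("SELECT array_prepend(CAST(null AS array<int>), 2)").first()[0]
    # None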
From 52078ff73a36893ab65f6ccc939e693b4328ba84 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 19:18:02 -0800 Subject: [PATCH 65/68] Fix python linter --- python/pyspark/sql/functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bcb88930cdcd..d920f0bf7c31 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7630,6 +7630,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) + @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ @@ -7656,9 +7657,10 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) + @try_remote_functions def array_prepend(col: "ColumnOrName", value: Any) -> Column: """ @@ -7749,6 +7751,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) + @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ From 8cd56bdc62306429f5a0dd310cfa671880787cab Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Thu, 16 Mar 2023 21:12:58 -0700 Subject: [PATCH 66/68] Fix merge --- .../sql/catalyst/analysis/FunctionRegistry.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7ff11b15c6eb..aca73741c639 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -971,7 +971,6 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { -<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -981,16 +980,6 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) -======= - try { - builder(expressions) - } catch { - case e: AnalysisException => - val argTypes = expressions.map(_.dataType.typeName).mkString(", ") - throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( - name, argTypes, info.getUsage, e.getMessage) - } ->>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } From 663473798517909bd0aee150703987bd62c176ce Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Thu, 16 Mar 2023 22:36:30 -0700 Subject: [PATCH 67/68] Fix MiMa --- .../sql/connect/client/CheckConnectJvmClientCompatibility.scala | 1 + .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 1 + 2 files changed, 2 insertions(+) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 97d130421a24..f50520c1a54d 100644 
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -177,6 +177,7 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.broadcast"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedlit"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedLit"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.array_prepend"), // RelationalGroupedDataset ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index fced26284885..3abc70a3d551 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -23,6 +23,7 @@ import java.util.TimeZone import scala.language.implicitConversions import scala.util.Random + import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow From 4dffbf70096df37a87a424c3fa77aff75f485386 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 17 Mar 2023 08:25:48 -0700 Subject: [PATCH 68/68] Fix indent --- .../scala/org/apache/spark/sql/functions.scala | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d771367f318c..5081f5822024 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4043,13 +4043,14 @@ object functions { def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) } - /** - * Returns an array containing value as well as all elements from array. The new element is - * positioned at the beginning of the array. - * - * @group collection_funcs - * @since 3.5.0 - */ + + /** + * Returns an array containing value as well as all elements from array. The new element is + * positioned at the beginning of the array. + * + * @group collection_funcs + * @since 3.5.0 + */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) }
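With the final indent fix the series is complete, and `array_prepend` is reachable from SQL, the Scala `functions` object, and PySpark. A closing PySpark sketch of the end state (illustrative; the DDL schema string, column name, and alias are assumptions, not taken from the patches):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import array_prepend

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([([2, 3, 4],), (None,)], "data array<int>")
    # Prepending to a NULL array leaves the result NULL, matching the SQL semantics.
    df.select(array_prepend("data", 1).alias("r")).collect()
    # [Row(r=[1, 2, 3, 4]), Row(r=None)]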