-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-41233][SQL][PYTHON] Add array_prepend function
#38947
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
78f960b
f2d4f68
85a8b4c
d7b601c
c703f85
d827d21
4e7aa1e
89e8917
ff8c19e
4f3a968
ec9ea76
c728581
6eba188
63ff6cc
73a7dd7
6f97761
fec08e9
a8da345
3af8fd9
2cd3e18
a307cb4
7b24500
f299962
7ce00b8
505b8e2
216ca4c
c121168
ec503f9
86a948f
f1c0186
34cb724
baa6cc7
8aa8ae5
1ba91c7
90c0c28
0a69172
db59880
f0d9329
ae5b65e
af3ee0a
3265717
7df63ea
b4fbbd5
413af39
9992e33
684a7d9
93f1819
a8db7b3
82186b9
30988b7
09a61ca
d188279
15b713d
380b156
4ecfac8
160db20
3aa673f
8b480a5
19fe924
b1cf31a
c0c6a51
422f393
2e193e4
95673b8
46e6dd7
67a64da
19505ff
52078ff
e0244ad
8cd56bd
6634737
4dffbf7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1399,6 +1399,152 @@ case class ArrayContains(left: Expression, right: Expression) | |
| copy(left = newLeft, right = newRight) | ||
| } | ||
|
|
||
// scalastyle:off line.size.limit
/**
 * Returns a new array made of `right` (the element) followed by all elements of `left`
 * (the array).
 *
 * The result is NULL when the array is NULL. A NULL element is not treated specially: it is
 * prepended like any other value.
 */
@ExpressionDescription(
  usage = """
      _FUNC_(array, element) - Add the element at the beginning of the array passed as first
      argument. Type of element should be the same as the type of the elements of the array.
      Null element is also prepended to the array. But if the array passed is NULL
      output is NULL
    """,
  examples = """
    Examples:
      > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
       ["d","b","d","c","a"]
      > SELECT _FUNC_(array(1, 2, 3, null), null);
       [null,1,2,3,null]
      > SELECT _FUNC_(CAST(null as Array<Int>), 2);
       NULL
  """,
  group = "array_funcs",
  since = "3.5.0")
case class ArrayPrepend(left: Expression, right: Expression)
  extends BinaryExpression
  with ImplicitCastInputTypes
  with ComplexTypeMergingExpression
  with QueryErrorsBase {

  /** Only a NULL input array yields NULL; a NULL element never makes the result NULL. */
  override def nullable: Boolean = left.nullable

  /** Element type of the (coerced) input array, taken from the first entry of `inputTypes`. */
  @transient protected lazy val elementType: DataType =
    inputTypes.head.asInstanceOf[ArrayType].elementType

  /**
   * Evaluates the array first and short-circuits to null when it is null. The element is
   * evaluated afterwards and handed to `nullSafeEval` even when it is null, because a null
   * element must still be prepended.
   */
  override def eval(input: InternalRow): Any = {
    val value1 = left.eval(input)
    if (value1 == null) {
      null
    } else {
      val value2 = right.eval(input)
      nullSafeEval(value1, value2)
    }
  }

  /**
   * Builds the result array: slot 0 holds `elementData` (possibly null), slots 1..n hold the
   * elements of `arr` in their original order.
   *
   * @param arr non-null input array, expected to be an `ArrayData`
   * @param elementData value to prepend; may be null
   * @return a `GenericArrayData` with one more element than `arr`
   * @throws an error from `QueryExecutionErrors.concatArraysWithElementsExceedLimitError` when
   *         the resulting length exceeds `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`
   */
  override def nullSafeEval(arr: Any, elementData: Any): Any = {
    val arrayData = arr.asInstanceOf[ArrayData]
    val numberOfElements = arrayData.numElements() + 1
    if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
      throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements)
    }
    val finalData = new Array[Any](numberOfElements)
    finalData.update(0, elementData)
    // Shift every source element one slot to the right.
    arrayData.foreach(elementType, (i: Int, v: Any) => finalData.update(i + 1, v))
    new GenericArrayData(finalData)
  }

  /**
   * Generates Java code that allocates an array of `numElements() + 1`, stores the element at
   * index 0 and copies the input array into indices 1..n.
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val leftGen = left.genCode(ctx)
    val rightGen = right.genCode(ctx)
    // Whether the source array may hold null slots. The copy loop must check for them;
    // otherwise null slots of primitive-typed arrays would be read with the raw typed getter
    // and written back as ordinary (non-null) values, diverging from the interpreted path.
    val sourceContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull
    val f = (arr: String, value: String) => {
      val newArraySize = s"$arr.numElements() + 1"
      val newArray = ctx.freshName("newArray")
      val i = ctx.freshName("i")
      val iPlus1 = s"$i+1"
      val zero = "0"
      val allocation = CodeGenerator.createArrayData(
        newArray,
        elementType,
        newArraySize,
        s" $prettyName failed.")
      val assignment = CodeGenerator.createArrayAssignment(
        newArray, elementType, arr, iPlus1, i, sourceContainsNull)
      // The prepended element itself may be null, so guard it with the right child's isNull.
      val newElemAssignment =
        CodeGenerator.setArrayElement(newArray, elementType, zero, value, Some(rightGen.isNull))
      s"""
         |$allocation
         |$newElemAssignment
         |for (int $i = 0; $i < $arr.numElements(); $i ++) {
         | $assignment
         |}
         |${ev.value} = $newArray;
         |""".stripMargin
    }
    val resultCode = f(leftGen.value, rightGen.value)
    if (nullable) {
      // Result starts out as null and is only materialized when the input array is non-null.
      val nullSafeEval = leftGen.code + rightGen.code + ctx.nullSafeExec(nullable, leftGen.isNull) {
        s"""
           |${ev.isNull} = false;
           |${resultCode}
           |""".stripMargin
      }
      ev.copy(code =
        code"""
          |boolean ${ev.isNull} = true;
          |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
          |$nullSafeEval
        """.stripMargin
      )
    } else {
      // The input array can never be null, so no isNull variable is emitted at all.
      ev.copy(code =
        code"""
          |${leftGen.code}
          |${rightGen.code}
          |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
          |$resultCode
        """.stripMargin, isNull = FalseLiteral)
    }
  }

  override def prettyName: String = "array_prepend"

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): ArrayPrepend =
    copy(left = newLeft, right = newRight)

  /**
   * The result has the array's type; when the element may be null the type is widened with
   * `asNullable` so that the prepended null is representable.
   */
  override def dataType: DataType = if (right.nullable) left.dataType.asNullable else left.dataType

  /**
   * Accepts only an array on the left whose element type is the same type as the right child.
   * Reports `ARRAY_FUNCTION_DIFF_TYPES` when the left child is an array of a different element
   * type, and `UNEXPECTED_INPUT_TYPE` when the left child is not an array at all.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    (left.dataType, right.dataType) match {
      case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess
      case (ArrayType(_, _), _) => DataTypeMismatch(
        errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
        messageParameters = Map(
          "functionName" -> toSQLId(prettyName),
          "leftType" -> toSQLType(left.dataType),
          "rightType" -> toSQLType(right.dataType),
          "dataType" -> toSQLType(ArrayType)
        ))
      case _ =>
        DataTypeMismatch(
          errorSubClass = "UNEXPECTED_INPUT_TYPE",
          messageParameters = Map(
            "paramIndex" -> "0",
            "requiredType" -> toSQLType(ArrayType),
            "inputSql" -> toSQLExpr(left),
            "inputType" -> toSQLType(left.dataType)
          )
        )
    }
  }

  /**
   * Expected input types used for implicit casting: the tightest common type `dt` of the array's
   * element type and the right child's type gives `Seq(ArrayType(dt, containsNull), dt)`.
   * Returns an empty sequence when the left child is not an array or no common type exists.
   */
  override def inputTypes: Seq[AbstractDataType] = {
    (left.dataType, right.dataType) match {
      case (ArrayType(e1, hasNull), e2) =>
        TypeCoercion.findTightestCommonType(e1, e2) match {
          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
          case _ => Seq.empty
        }
      case _ => Seq.empty
    }
  }
}
|
|
||
| /** | ||
| * Checks if the two arrays contain at least one common element. | ||
| */ | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.