-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-23736][SQL] Extending the concat function to support array columns #20858
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
282e724
aa5a089
90d3ab7
bb46c3d
11205af
753499d
2efdd77
fd84bee
116f91f
e199ac5
067c2db
090929f
8abd1a8
367ee22
6bb33e6
57b250c
944e0c9
7f5124b
0201e4b
600ae89
f2a67e8
8a125d9
5a4cc8c
f7bdcf7
36d5d25
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | |
| import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
| import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.unsafe.types.UTF8String | ||
| import org.apache.spark.unsafe.Platform | ||
| import org.apache.spark.unsafe.array.ByteArrayMethods | ||
| import org.apache.spark.unsafe.types.{ByteArray, UTF8String} | ||
|
|
||
| /** | ||
| * Given an array or map, returns its size. Returns -1 if null. | ||
|
|
@@ -665,3 +667,219 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti | |
|
|
||
| override def prettyName: String = "element_at" | ||
| } | ||
|
|
||
| /** | ||
| * Concatenates multiple input columns together into a single column. | ||
| * The function works with strings, binary and compatible array columns. | ||
| */ | ||
| @ExpressionDescription( | ||
| usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_('Spark', 'SQL'); | ||
| SparkSQL | ||
| > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); | ||
| | [1,2,3,4,5,6] | ||
| """) | ||
| case class Concat(children: Seq[Expression]) extends Expression { | ||
|
|
||
| // Size limit applied to concatenated arrays: compared against the total element count in | ||
| // eval and in the generated element-count code, and against the byte size of the backing | ||
| // buffer in the primitive-array code path. | ||
| private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH | ||
|
|
||
| // Type families a child may have; checkInputDataTypes tests each child type against these | ||
| // with acceptsType. | ||
| val allowedTypes = Seq(StringType, BinaryType, ArrayType) | ||
|
|
||
| /** | ||
| * Validates the child types. No children is accepted; otherwise every child type must be | ||
| * accepted by one of `allowedTypes`, and the remaining same-type check is delegated to | ||
| * `TypeUtils.checkForSameTypeInputExpr`. | ||
| */ | ||
| override def checkInputDataTypes(): TypeCheckResult = { | ||
| val inputTypes = children.map(_.dataType) | ||
| def isSupported(tpe: DataType): Boolean = allowedTypes.exists(_.acceptsType(tpe)) | ||
| if (inputTypes.isEmpty) { | ||
| TypeCheckResult.TypeCheckSuccess | ||
| } else if (inputTypes.forall(isSupported)) { | ||
| TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") | ||
| } else { | ||
| TypeCheckResult.TypeCheckFailure( | ||
| s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + | ||
| s" but it's " + inputTypes.map(_.simpleString).mkString("[", ", ", "]")) | ||
| } | ||
| } | ||
|
|
||
| // Result type is the type of the first child; with no children it falls back to StringType. | ||
| override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) | ||
|
|
||
| // Java type name for dataType as reported by CodeGenerator.javaType; used in the generated | ||
| // code for the argument array and the result variable. | ||
| lazy val javaType: String = CodeGenerator.javaType(dataType) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can move this into
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point! But I think it would be better to reuse |
||
|
|
||
| // The result may be null as soon as any child may be null. | ||
| override def nullable: Boolean = children.exists(_.nullable) | ||
|
|
||
| // The expression is foldable only when every child is foldable. | ||
| override def foldable: Boolean = children.forall(_.foldable) | ||
|
|
||
| /** | ||
| * Concatenation routine for this expression, chosen once from `dataType`. | ||
| * | ||
| * Resolving the data type pattern match here, instead of inside `eval`, keeps the interpreted | ||
| * (non-codegen) path from repeating the match for every input row. | ||
| * | ||
| * - BinaryType / StringType: all children are evaluated and handed to `ByteArray.concat` / | ||
| * `UTF8String.concat`. | ||
| * - ArrayType: children are evaluated lazily and the result is null as soon as one of them | ||
| * evaluates to null; otherwise all elements are copied into a new `GenericArrayData`. | ||
| * A RuntimeException is thrown when the total number of elements exceeds | ||
| * `MAX_ARRAY_LENGTH`. | ||
| */ | ||
| @transient private lazy val doConcat: InternalRow => Any = dataType match { | ||
| case BinaryType => | ||
| input => { | ||
| val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) | ||
| ByteArray.concat(inputs: _*) | ||
| } | ||
| case StringType => | ||
| input => { | ||
| val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) | ||
| UTF8String.concat(inputs : _*) | ||
| } | ||
| case ArrayType(elementType, _) => | ||
| input => { | ||
| // A Stream keeps child evaluation lazy, so children after the first null are not evaluated. | ||
| val inputs = children.toStream.map(_.eval(input)) | ||
| if (inputs.contains(null)) { | ||
| null | ||
| } else { | ||
| val arrayData = inputs.map(_.asInstanceOf[ArrayData]) | ||
| // Summed as Long so that the limit check itself cannot overflow. | ||
| val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) | ||
| if (numberOfElements > MAX_ARRAY_LENGTH) { | ||
| throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + | ||
| s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") | ||
| } | ||
| val finalData = new Array[AnyRef](numberOfElements.toInt) | ||
| var position = 0 | ||
| for (ad <- arrayData) { | ||
| val arr = ad.toObjectArray(elementType) | ||
| Array.copy(arr, 0, finalData, position, arr.length) | ||
| position += arr.length | ||
| } | ||
| new GenericArrayData(finalData) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Interpreted evaluation: applies the routine selected in `doConcat` to the given row. | ||
| */ | ||
| override def eval(input: InternalRow): Any = doConcat(input) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so this pattern match will probably cause significant regression in the interpreted (non-codegen) mode, due to the way scala pattern matching is implemented.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| /** | ||
| * Generates the Java code for this expression. | ||
| * | ||
| * Each child's generated code stores its value into a fresh `args` array at the child's | ||
| * index when the child is not null; the slot is left unassigned otherwise. The array is | ||
| * then passed to a `concat` call whose receiver depends on `dataType`: the ByteArray class | ||
| * name, the literal "UTF8String", or a code snippet produced by one of the | ||
| * genCodeFor*Arrays helpers. The result is marked null when `concat` returns null. | ||
| */ | ||
| override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| val evals = children.map(_.genCode(ctx)) | ||
| val args = ctx.freshName("args") | ||
|
|
||
| // One snippet per child: evaluate it and record the value only when it is not null. | ||
| val inputs = evals.zipWithIndex.map { case (eval, index) => | ||
| s""" | ||
| ${eval.code} | ||
| if (!${eval.isNull}) { | ||
| $args[$index] = ${eval.value}; | ||
| } | ||
| """ | ||
| } | ||
|
|
||
| // Receiver of the generated `concat` call plus the declaration of the argument array. | ||
| val (concatenator, initCode) = dataType match { | ||
| case BinaryType => | ||
| (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") | ||
| case StringType => | ||
| ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") | ||
| case ArrayType(elementType, _) => | ||
| val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { | ||
| genCodeForPrimitiveArrays(ctx, elementType) | ||
| } else { | ||
| genCodeForNonPrimitiveArrays(ctx, elementType) | ||
| } | ||
| (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") | ||
| } | ||
| // The per-child snippets go through splitExpressionsWithCurrentInputs, with the argument | ||
| // array handed over as an extra argument. | ||
| val codes = ctx.splitExpressionsWithCurrentInputs( | ||
| expressions = inputs, | ||
| funcName = "valueConcat", | ||
| extraArguments = (s"$javaType[]", args) :: Nil) | ||
| ev.copy(s""" | ||
| $initCode | ||
| $codes | ||
| $javaType ${ev.value} = $concatenator.concat($args); | ||
| boolean ${ev.isNull} = ${ev.value} == null; | ||
| """) | ||
| } | ||
|
|
||
| /** | ||
| * Builds Java code that sums `numElements()` over the first `children.length` entries of a | ||
| * variable literally named `args` and throws a RuntimeException when the sum exceeds | ||
| * `MAX_ARRAY_LENGTH`. | ||
| * | ||
| * The snippet refers to `args` by that fixed name, so it has to be emitted where such a | ||
| * variable is in scope (the `concat` methods generated by the genCodeFor*Arrays helpers | ||
| * declare a parameter with this name). | ||
| * | ||
| * @return a pair of the generated code and the name of the `long` variable holding the sum | ||
| */ | ||
| private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { | ||
| val numElements = ctx.freshName("numElements") | ||
| val code = s""" | ||
| |long $numElements = 0L; | ||
| |for (int z = 0; z < ${children.length}; z++) { | ||
| | $numElements += args[z].numElements(); | ||
| |} | ||
| |if ($numElements > $MAX_ARRAY_LENGTH) { | ||
| | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + | ||
| | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); | ||
| |} | ||
| """.stripMargin | ||
|
|
||
| (code, numElements) | ||
| } | ||
|
|
||
| /** | ||
| * Java snippet that makes the generated `concat` method return null when any entry of | ||
| * `args` is null. Yields an empty string when this expression is not nullable. | ||
| */ | ||
| private def nullArgumentProtection() : String = { | ||
| val nullCheckLoop = | ||
| s""" | ||
| |for (int z = 0; z < ${children.length}; z++) { | ||
| | if (args[z] == null) return null; | ||
| |} | ||
| """.stripMargin | ||
| if (nullable) nullCheckLoop else "" | ||
| } | ||
|
|
||
| /** | ||
| * Builds a Java expression (an anonymous `Object` with a `concat` method) that concatenates | ||
| * arrays of a primitive element type into an `UnsafeArrayData` backed by a new byte array. | ||
| * | ||
| * The generated method applies the null-argument guard, counts the elements, computes the | ||
| * buffer size with `UnsafeArrayData.calculateSizeOfUnderlyingByteArray` and throws a | ||
| * RuntimeException when that size exceeds `MAX_ARRAY_LENGTH`. It then writes the element | ||
| * count at the start of the buffer and copies the elements one by one, carrying over nulls | ||
| * with `setNullAt`. | ||
| * | ||
| * @param elementType element type of the arrays being concatenated | ||
| * @return the generated Java expression | ||
| */ | ||
| private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { | ||
| val arrayName = ctx.freshName("array") | ||
| val arraySizeName = ctx.freshName("size") | ||
| val counter = ctx.freshName("counter") | ||
| val arrayData = ctx.freshName("arrayData") | ||
|
|
||
| val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) | ||
|
|
||
| // Size of the backing byte array, checked against the limit before the array is allocated. | ||
| val unsafeArraySizeInBytes = s""" | ||
| |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( | ||
| | $numElemName, | ||
| | ${elementType.defaultSize}); | ||
| |if ($arraySizeName > $MAX_ARRAY_LENGTH) { | ||
| | throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName + | ||
| | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" + | ||
| | " for UnsafeArrayData."); | ||
| |} | ||
| """.stripMargin | ||
| val baseOffset = Platform.BYTE_ARRAY_OFFSET | ||
| val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) | ||
|
|
||
| // stripPrefix drops the newline that follows the opening triple quote. | ||
| s""" | ||
| |new Object() { | ||
| | public ArrayData concat($javaType[] args) { | ||
| | ${nullArgumentProtection()} | ||
| | $numElemCode | ||
| | $unsafeArraySizeInBytes | ||
| | byte[] $arrayName = new byte[(int)$arraySizeName]; | ||
| | UnsafeArrayData $arrayData = new UnsafeArrayData(); | ||
| | Platform.putLong($arrayName, $baseOffset, $numElemName); | ||
| | $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName); | ||
| | int $counter = 0; | ||
| | for (int y = 0; y < ${children.length}; y++) { | ||
| | for (int z = 0; z < args[y].numElements(); z++) { | ||
| | if (args[y].isNullAt(z)) { | ||
| | $arrayData.setNullAt($counter); | ||
| | } else { | ||
| | $arrayData.set$primitiveValueTypeName( | ||
| | $counter, | ||
| | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} | ||
| | ); | ||
| | } | ||
| | $counter++; | ||
| | } | ||
| | } | ||
| | return $arrayData; | ||
| | } | ||
| |}""".stripMargin.stripPrefix("\n") | ||
| } | ||
|
|
||
| /** | ||
| * Builds a Java expression (an anonymous `Object` with a `concat` method) that concatenates | ||
| * arrays of a non-primitive element type into a `GenericArrayData`. | ||
| * | ||
| * The generated method applies the null-argument guard, counts the elements (which includes | ||
| * the `MAX_ARRAY_LENGTH` check) and copies every element, read via `CodeGenerator.getValue`, | ||
| * into a new `Object[]` that is wrapped in the result. | ||
| * | ||
| * @param elementType element type of the arrays being concatenated | ||
| * @return the generated Java expression | ||
| */ | ||
| private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { | ||
| val genericArrayClass = classOf[GenericArrayData].getName | ||
| val arrayData = ctx.freshName("arrayObjects") | ||
| val counter = ctx.freshName("counter") | ||
|
|
||
| val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) | ||
|
|
||
| // stripPrefix drops the newline that follows the opening triple quote. | ||
| s""" | ||
| |new Object() { | ||
| | public ArrayData concat($javaType[] args) { | ||
| | ${nullArgumentProtection()} | ||
| | $numElemCode | ||
| | Object[] $arrayData = new Object[(int)$numElemName]; | ||
| | int $counter = 0; | ||
| | for (int y = 0; y < ${children.length}; y++) { | ||
| | for (int z = 0; z < args[y].numElements(); z++) { | ||
| | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; | ||
| | $counter++; | ||
| | } | ||
| | } | ||
| | return new $genericArrayClass($arrayData); | ||
| | } | ||
| |}""".stripMargin.stripPrefix("\n") | ||
| } | ||
|
|
||
| // Rendered as concat(child1, child2, ...) using each child's toString. | ||
| override def toString: String = s"concat(${children.mkString(", ")})" | ||
|
|
||
| // SQL form: concat(...) over each child's own sql rendering. | ||
| override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did we move this down?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The whole file is divided into sections according to groups of functions. Based on @gatorsmile's suggestion, the concat function should be categorized as a collection function, so I moved the function to comply with the file structure.