-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-23914][SQL] Add array_union function #21061
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dc9d6f0
3019840
8cee6cf
2041ec4
8c2280b
b3a3132
a2c7dd1
5313680
98f8d1f
30ee7fc
cd347e9
d2eaee3
2ddeb06
71b31f0
7e71340
04c97c3
401ca7a
15b953b
f050922
8a27667
e50bc55
7e3f2ef
e5401e7
3e21e48
3c39506
6654742
be9f331
90e84b3
6f721f0
0c0d3ba
4a217bc
f5ebbe8
763a1f8
7b51564
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -450,7 +450,7 @@ public double[] toDoubleArray() { | |
| return values; | ||
| } | ||
|
|
||
| private static UnsafeArrayData fromPrimitiveArray( | ||
| public static UnsafeArrayData fromPrimitiveArray( | ||
| Object arr, int offset, int length, int elementSize) { | ||
| final long headerInBytes = calculateHeaderPortionInBytes(length); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this logic extracted to
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this thread an answer to this question?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok. |
||
| final long valueRegionInBytes = (long)elementSize * length; | ||
|
|
@@ -463,14 +463,27 @@ private static UnsafeArrayData fromPrimitiveArray( | |
| final long[] data = new long[(int)totalSizeInLongs]; | ||
|
|
||
| Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length); | ||
| Platform.copyMemory(arr, offset, data, | ||
| Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); | ||
| if (arr != null) { | ||
| Platform.copyMemory(arr, offset, data, | ||
| Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); | ||
| } | ||
|
|
||
| UnsafeArrayData result = new UnsafeArrayData(); | ||
| result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8); | ||
| return result; | ||
| } | ||
|
|
||
| public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elementSize) { | ||
| return fromPrimitiveArray(null, offset, length, elementSize); | ||
| } | ||
|
|
||
| public static boolean shouldUseGenericArrayData(int elementSize, int length) { | ||
| final long headerInBytes = calculateHeaderPortionInBytes(length); | ||
| final long valueRegionInBytes = (long)elementSize * length; | ||
| final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; | ||
| return totalSizeInLongs > Integer.MAX_VALUE / 8; | ||
| } | ||
|
|
||
| public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) { | ||
| return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3261,3 +3261,322 @@ case class ArrayDistinct(child: Expression) | |
|
|
||
| override def prettyName: String = "array_distinct" | ||
| } | ||
|
|
||
/**
 * Shared base for binary array expressions with set semantics. Intended to become the common
 * parent of [[ArrayUnion]], ArrayIntersect, and ArrayExcept.
 */
abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
  /** An array of `elementType` whose elements are nullable if any child array may hold nulls. */
  override def dataType: DataType = {
    // Cast every child eagerly so a non-array child fails regardless of its position.
    val childNullability = children.map(_.dataType.asInstanceOf[ArrayType].containsNull)
    ArrayType(elementType, childNullability.contains(true))
  }

  /** Runs the inherited check first; on success, also requires an orderable element type. */
  override def checkInputDataTypes(): TypeCheckResult = {
    val inherited = super.checkInputDataTypes()
    if (!inherited.isSuccess) {
      inherited
    } else {
      val resultElementType = dataType.asInstanceOf[ArrayType].elementType
      TypeUtils.checkForOrderingExpr(resultElementType, s"function $prettyName")
    }
  }

  /** Interpreted ordering over the element type, built lazily. */
  @transient protected lazy val ordering: Ordering[Any] =
    TypeUtils.getInterpretedOrdering(elementType)

  /** True for atomic element types other than binary; false for everything else. */
  @transient protected lazy val elementTypeSupportEquals = elementType match {
    case BinaryType => false
    case other => other.isInstanceOf[AtomicType]
  }
}
|
|
||
object ArraySetLike {
  /**
   * Always throws a `RuntimeException` reporting that a union of `length` elements exceeds
   * `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`.
   *
   * @param length the element count reported in the error message
   */
  def throwUnionLengthOverflowException(length: Int): Unit = {
    val message = s"Unsuccessful try to union arrays with $length " +
      s"elements due to exceeding the array size limit " +
      s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."
    throw new RuntimeException(message)
  }
}
|
|
||
|
|
||
| /** | ||
| * Returns an array of the elements in the union of x and y, without duplicates | ||
| */ | ||
| @ExpressionDescription( | ||
| usage = """ | ||
| _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, | ||
| without duplicates. | ||
| """, | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); | ||
| array(1, 2, 3, 5) | ||
| """, | ||
| since = "2.4.0") | ||
| case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { | ||
| var hsInt: OpenHashSet[Int] = _ | ||
| var hsLong: OpenHashSet[Long] = _ | ||
|
|
||
/**
 * Reads the int at `idx` of `array` and, if it is not yet in `hsInt`, records it there and
 * (when `resultArray` is non-null) writes it to `resultArray` at `pos`.
 *
 * @return true if the value was new, false if it had already been seen
 */
def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
  val candidate = array.getInt(idx)
  val isNew = !hsInt.contains(candidate)
  if (isNew) {
    // A null resultArray means the caller is only counting distinct values.
    if (resultArray != null) {
      resultArray.setInt(pos, candidate)
    }
    hsInt.add(candidate)
  }
  isNew
}
|
|
||
/**
 * Reads the long at `idx` of `array` and, if it is not yet in `hsLong`, records it there and
 * (when `resultArray` is non-null) writes it to `resultArray` at `pos`.
 *
 * @return true if the value was new, false if it had already been seen
 */
def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
  val candidate = array.getLong(idx)
  val isNew = !hsLong.contains(candidate)
  if (isNew) {
    // A null resultArray means the caller is only counting distinct values.
    if (resultArray != null) {
      resultArray.setLong(pos, candidate)
    }
    hsLong.add(candidate)
  }
  isNew
}
|
|
||
/**
 * Walks `array1` then `array2`, keeping the first occurrence of each distinct value and at
 * most one null. When `resultArray` is non-null the kept entries are written into it in
 * encounter order; when it is null the method only counts them.
 *
 * The active hash set (`hsInt` or `hsLong`, chosen by `isLongType`) must be initialised by
 * the caller before this is invoked.
 *
 * @return the number of entries kept (distinct values plus the optional null)
 */
def evalIntLongPrimitiveType(
    array1: ArrayData,
    array2: ArrayData,
    resultArray: ArrayData,
    isLongType: Boolean): Int = {
  var nullSlots = 0
  var writeIdx = 0
  for (source <- Seq(array1, array2)) {
    var readIdx = 0
    while (readIdx < source.numElements()) {
      val distinctSoFar = if (isLongType) hsLong.size else hsInt.size
      if (distinctSoFar + nullSlots > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
        ArraySetLike.throwUnionLengthOverflowException(distinctSoFar)
      }
      if (!source.isNullAt(readIdx)) {
        val added =
          if (isLongType) {
            assignLong(source, readIdx, resultArray, writeIdx)
          } else {
            assignInt(source, readIdx, resultArray, writeIdx)
          }
        if (added) {
          writeIdx += 1
        }
      } else if (nullSlots == 0) {
        // Only the first null encountered occupies a slot in the output.
        if (resultArray != null) {
          resultArray.setNullAt(writeIdx)
        }
        writeIdx += 1
        nullSlots = 1
      }
      readIdx += 1
    }
  }
  writeIdx
}
|
|
||
| override def nullSafeEval(input1: Any, input2: Any): Any = { | ||
| val array1 = input1.asInstanceOf[ArrayData] | ||
| val array2 = input2.asInstanceOf[ArrayData] | ||
|
|
||
| if (elementTypeSupportEquals) { | ||
| elementType match { | ||
| case IntegerType => | ||
| // avoid boxing of primitive int array elements | ||
| // calculate result array size | ||
| hsInt = new OpenHashSet[Int] | ||
| val elements = evalIntLongPrimitiveType(array1, array2, null, false) | ||
| hsInt = new OpenHashSet[Int] | ||
| val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( | ||
| IntegerType.defaultSize, elements)) { | ||
| new GenericArrayData(new Array[Any](elements)) | ||
| } else { | ||
| UnsafeArrayData.forPrimitiveArray( | ||
| Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) | ||
| } | ||
| evalIntLongPrimitiveType(array1, array2, resultArray, false) | ||
| resultArray | ||
| case LongType => | ||
| // avoid boxing of primitive long array elements | ||
| // calculate result array size | ||
| hsLong = new OpenHashSet[Long] | ||
| val elements = evalIntLongPrimitiveType(array1, array2, null, true) | ||
| hsLong = new OpenHashSet[Long] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Once we have obtained the unique elements of the two arrays in the hash set, can't we get the final array elements from it directly instead of scanning the two arrays again?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It could be. Originally, I took that approach.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, though I think there will be some performance issue. |
||
| val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( | ||
| LongType.defaultSize, elements)) { | ||
| new GenericArrayData(new Array[Any](elements)) | ||
| } else { | ||
| UnsafeArrayData.forPrimitiveArray( | ||
| Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) | ||
| } | ||
| evalIntLongPrimitiveType(array1, array2, resultArray, true) | ||
| resultArray | ||
| case _ => | ||
| val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] | ||
| val hs = new OpenHashSet[Any] | ||
| var foundNullElement = false | ||
| Seq(array1, array2).foreach { array => | ||
| var i = 0 | ||
| while (i < array.numElements()) { | ||
| if (array.isNullAt(i)) { | ||
| if (!foundNullElement) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: This two
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is not easy since we want to do nothing if |
||
| arrayBuffer += null | ||
| foundNullElement = true | ||
| } | ||
| } else { | ||
| val elem = array.get(i, elementType) | ||
| if (!hs.contains(elem)) { | ||
| if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { | ||
| ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) | ||
| } | ||
| arrayBuffer += elem | ||
| hs.add(elem) | ||
| } | ||
| } | ||
| i += 1 | ||
| } | ||
| } | ||
| new GenericArrayData(arrayBuffer) | ||
| } | ||
| } else { | ||
| ArrayUnion.unionOrdering(array1, array2, elementType, ordering) | ||
| } | ||
| } | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| val i = ctx.freshName("i") | ||
| val pos = ctx.freshName("pos") | ||
| val value = ctx.freshName("value") | ||
| val size = ctx.freshName("size") | ||
| val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) = | ||
| if (elementTypeSupportEquals) { | ||
| elementType match { | ||
| case ByteType | ShortType | IntegerType | LongType => | ||
| val ptName = CodeGenerator.primitiveTypeName(elementType) | ||
| val unsafeArray = ctx.freshName("unsafeArray") | ||
| (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", | ||
| if (elementType == LongType) "Long" else "Int", | ||
| s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), | ||
| if (elementType == LongType) "(long)" else "(int)", | ||
| s""" | ||
| |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like we don't automatically choose to use
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Your comment is correct. It would be good to address this choice in another PR to update
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean a refactoring around the usage of
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I mean a refactoring the usage of |
||
| |${ev.value} = $unsafeArray; | ||
| """.stripMargin) | ||
| case _ => | ||
| val genericArrayData = classOf[GenericArrayData].getName | ||
| val et = ctx.addReferenceObj("elementType", elementType) | ||
| ("", "Object", | ||
| s"get($i, $et)", s"update($pos, $value)", "Object", "", | ||
| s"${ev.value} = new $genericArrayData(new Object[$size]);") | ||
| } | ||
| } else { | ||
| ("", "", "", "", "", "", "") | ||
| } | ||
|
|
||
| nullSafeCodeGen(ctx, ev, (array1, array2) => { | ||
| if (openHashElementType != "") { | ||
| // Here, we ensure elementTypeSupportEquals is true | ||
| val foundNullElement = ctx.freshName("foundNullElement") | ||
| val openHashSet = classOf[OpenHashSet[_]].getName | ||
| val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" | ||
| val hs = ctx.freshName("hs") | ||
| val arrayData = classOf[ArrayData].getName | ||
| val arrays = ctx.freshName("arrays") | ||
| val array = ctx.freshName("array") | ||
| val arrayDataIdx = ctx.freshName("arrayDataIdx") | ||
| s""" | ||
| |$openHashSet $hs = new $openHashSet$postFix($classTag); | ||
| |boolean $foundNullElement = false; | ||
| |$arrayData[] $arrays = new $arrayData[]{$array1, $array2}; | ||
| |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { | ||
| | $arrayData $array = $arrays[$arrayDataIdx]; | ||
| | for (int $i = 0; $i < $array.numElements(); $i++) { | ||
| | if ($array.isNullAt($i)) { | ||
| | $foundNullElement = true; | ||
| | } else { | ||
| | $hs.add$postFix($array.$getter); | ||
| | } | ||
| | } | ||
| |} | ||
| |int $size = $hs.size() + ($foundNullElement ? 1 : 0); | ||
| |$arrayBuilder | ||
| |$hs = new $openHashSet$postFix($classTag); | ||
| |$foundNullElement = false; | ||
| |int $pos = 0; | ||
| |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { | ||
| | $arrayData $array = $arrays[$arrayDataIdx]; | ||
| | for (int $i = 0; $i < $array.numElements(); $i++) { | ||
| | if ($array.isNullAt($i)) { | ||
| | if (!$foundNullElement) { | ||
| | ${ev.value}.setNullAt($pos++); | ||
| | $foundNullElement = true; | ||
| | } | ||
| | } else { | ||
| | $javaTypeName $value = $array.$getter; | ||
| | if (!$hs.contains($castOp $value)) { | ||
| | $hs.add$postFix($value); | ||
| | ${ev.value}.$setter; | ||
| | $pos++; | ||
| | } | ||
| | } | ||
| | } | ||
| |} | ||
| """.stripMargin | ||
| } else { | ||
| val arrayUnion = classOf[ArrayUnion].getName | ||
| val et = ctx.addReferenceObj("elementTypeUnion", elementType) | ||
| val order = ctx.addReferenceObj("orderingUnion", ordering) | ||
| val method = "unionOrdering" | ||
| s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);" | ||
| } | ||
| }) | ||
| } | ||
|
|
||
| override def prettyName: String = "array_union" | ||
| } | ||
|
|
||
object ArrayUnion {
  /**
   * Unions two arrays using only `ordering.equiv` for equality, so it works for element types
   * that cannot be placed in a hash set. Elements keep first-occurrence order and at most one
   * null is retained. Each candidate is compared against every element kept so far.
   *
   * @param array1      first input array
   * @param array2      second input array
   * @param elementType element type used to read values out of both arrays
   * @param ordering    ordering whose `equiv` decides whether two non-null elements are equal
   * @return a `GenericArrayData` holding the distinct elements
   */
  def unionOrdering(
      array1: ArrayData,
      array2: ArrayData,
      elementType: DataType,
      ordering: Ordering[Any]): ArrayData = {
    val kept = new scala.collection.mutable.ArrayBuffer[Any]
    var nullSeen = false
    Seq(array1, array2).foreach { source =>
      source.foreach(elementType, (_, candidate) => {
        val duplicate =
          if (candidate == null) {
            // The first null is kept; every later null counts as a duplicate.
            val seenBefore = nullSeen
            nullSeen = true
            seenBefore
          } else {
            kept.exists(existing => existing != null && ordering.equiv(existing, candidate))
          }
        if (!duplicate) {
          if (kept.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
            ArraySetLike.throwUnionLengthOverflowException(kept.length)
          }
          kept += candidate
        }
      })
    }
    new GenericArrayData(kept)
  }
}
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the array in col1 itself contains duplicate elements, what does the function do? Does it de-duplicate them too?
E.g.,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After reading the code, it seems to de-duplicate all elements from the two arrays. Is this behavior the same as Presto's?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will add the tests for duplication.
Yes, this will de-duplicate. I think that it is the same behavior as Presto.