Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
dc9d6f0
initial commit
kiszk Apr 13, 2018
3019840
update description
kiszk Apr 13, 2018
8cee6cf
fix test failure
kiszk Apr 13, 2018
2041ec4
address review comments
kiszk Apr 17, 2018
8c2280b
introduce ArraySetUtils to reuse code among array_union/array_interse…
kiszk Apr 17, 2018
b3a3132
fix python test failure
kiszk Apr 18, 2018
a2c7dd1
fix python test failure
kiszk Apr 18, 2018
5313680
simplification
kiszk Apr 18, 2018
98f8d1f
fix pyspark test failure
kiszk Apr 19, 2018
30ee7fc
address review comments
kiszk Apr 20, 2018
cd347e9
add new tests based on review comment
kiszk Apr 20, 2018
d2eaee3
fix mistakes in rebase
kiszk Apr 20, 2018
2ddeb06
fix unexpected changes
kiszk Apr 20, 2018
71b31f0
merge changes in #21103
kiszk Apr 20, 2018
7e71340
use GenericArrayData if UnsafeArrayData cannot be used
kiszk May 4, 2018
04c97c3
use BinaryArrayExpressionWithImplicitCast
kiszk May 4, 2018
401ca7a
update test cases
kiszk May 4, 2018
15b953b
rebase with master
kiszk May 17, 2018
f050922
support complex types
kiszk May 18, 2018
8a27667
add test cases with duplication in an array
kiszk May 19, 2018
e50bc55
rebase with master
kiszk Jun 1, 2018
7e3f2ef
address review comments
kiszk Jun 1, 2018
e5401e7
address review comment
kiszk Jun 1, 2018
3e21e48
keep the order of input array elements
kiszk Jun 10, 2018
3c39506
address review comments
kiszk Jun 20, 2018
6654742
fix scala style error
kiszk Jun 20, 2018
be9f331
address review comment
kiszk Jun 20, 2018
90e84b3
address review comments
kiszk Jun 22, 2018
6f721f0
address review comments
kiszk Jun 22, 2018
0c0d3ba
address review comments
kiszk Jul 8, 2018
4a217bc
cleanup
kiszk Jul 8, 2018
f5ebbe8
eliminate duplicated code
kiszk Jul 8, 2018
763a1f8
address review comments
kiszk Jul 9, 2018
7b51564
address review comment
kiszk Jul 11, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2013,6 +2013,25 @@ def array_distinct(col):
return Column(sc._jvm.functions.array_distinct(_to_java_column(col)))


@ignore_unicode_prefix
@since(2.4)
def array_union(col1, col2):
"""
Collection function: returns an array of the elements in the union of col1 and col2,
Copy link
Member

@viirya viirya Jul 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the array of col1 contains duplicate elements itself, what it does? de-duplicate them too?

E.g.,

df = spark.createDataFrame([Row(c1=["b", "a", "c", "c"], c2=["c", "d", "a", "f"])])
df.select(array_union(df.c1, df.c2)).collect()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After reading the code, seems it de-duplicates all elements from two arrays. Is this behavior the same as Presto?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add the tests for duplication.
Yes, this will de-duplicate. I think that it is the same behavior as Presto.

without duplicates.

:param col1: name of column containing array
:param col2: name of column containing array

>>> from pyspark.sql import Row
>>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
>>> df.select(array_union(df.c1, df.c2)).collect()
[Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2)))


@since(1.4)
def explode(col):
"""Returns a new row for each element in the given array or map.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ public double[] toDoubleArray() {
return values;
}

private static UnsafeArrayData fromPrimitiveArray(
public static UnsafeArrayData fromPrimitiveArray(
Object arr, int offset, int length, int elementSize) {
final long headerInBytes = calculateHeaderPortionInBytes(length);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this logic extracted to useGenericArrayData? If so, can we re-use it by calling the method here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this thread an answer to this question?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok.

final long valueRegionInBytes = (long)elementSize * length;
Expand All @@ -463,14 +463,27 @@ private static UnsafeArrayData fromPrimitiveArray(
final long[] data = new long[(int)totalSizeInLongs];

Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length);
Platform.copyMemory(arr, offset, data,
Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes);
if (arr != null) {
Platform.copyMemory(arr, offset, data,
Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes);
}

UnsafeArrayData result = new UnsafeArrayData();
result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8);
return result;
}

public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elementSize) {
return fromPrimitiveArray(null, offset, length, elementSize);
}

public static boolean shouldUseGenericArrayData(int elementSize, int length) {
final long headerInBytes = calculateHeaderPortionInBytes(length);
final long valueRegionInBytes = (long)elementSize * length;
final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;
return totalSizeInLongs > Integer.MAX_VALUE / 8;
}

public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) {
return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ object FunctionRegistry {
expression[ArrayJoin]("array_join"),
expression[ArrayPosition]("array_position"),
expression[ArraySort]("array_sort"),
expression[ArrayUnion]("array_union"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[ElementAt]("element_at"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3261,3 +3261,322 @@ case class ArrayDistinct(child: Expression)

override def prettyName: String = "array_distinct"
}

/**
* Will become common base class for [[ArrayUnion]], ArrayIntersect, and ArrayExcept.
*/
abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
override def dataType: DataType = {
val dataTypes = children.map(_.dataType.asInstanceOf[ArrayType])
ArrayType(elementType, dataTypes.exists(_.containsNull))
}

override def checkInputDataTypes(): TypeCheckResult = {
val typeCheckResult = super.checkInputDataTypes()
if (typeCheckResult.isSuccess) {
TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType,
s"function $prettyName")
} else {
typeCheckResult
}
}

@transient protected lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(elementType)

@transient protected lazy val elementTypeSupportEquals = elementType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
}
}

object ArraySetLike {
def throwUnionLengthOverflowException(length: Int): Unit = {
throw new RuntimeException(s"Unsuccessful try to union arrays with $length " +
s"elements due to exceeding the array size limit " +
s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
}
}


/**
* Returns an array of the elements in the union of x and y, without duplicates
*/
@ExpressionDescription(
usage = """
_FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2,
without duplicates.
""",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
array(1, 2, 3, 5)
""",
since = "2.4.0")
case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike {
var hsInt: OpenHashSet[Int] = _
var hsLong: OpenHashSet[Long] = _

def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
val elem = array.getInt(idx)
if (!hsInt.contains(elem)) {
if (resultArray != null) {
resultArray.setInt(pos, elem)
}
hsInt.add(elem)
true
} else {
false
}
}

def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
val elem = array.getLong(idx)
if (!hsLong.contains(elem)) {
if (resultArray != null) {
resultArray.setLong(pos, elem)
}
hsLong.add(elem)
true
} else {
false
}
}

def evalIntLongPrimitiveType(
array1: ArrayData,
array2: ArrayData,
resultArray: ArrayData,
isLongType: Boolean): Int = {
// store elements into resultArray
var nullElementSize = 0
var pos = 0
Seq(array1, array2).foreach { array =>
var i = 0
while (i < array.numElements()) {
val size = if (!isLongType) hsInt.size else hsLong.size
if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
ArraySetLike.throwUnionLengthOverflowException(size)
}
if (array.isNullAt(i)) {
if (nullElementSize == 0) {
if (resultArray != null) {
resultArray.setNullAt(pos)
}
pos += 1
nullElementSize = 1
}
} else {
val assigned = if (!isLongType) {
assignInt(array, i, resultArray, pos)
} else {
assignLong(array, i, resultArray, pos)
}
if (assigned) {
pos += 1
}
}
i += 1
}
}
pos
}

override def nullSafeEval(input1: Any, input2: Any): Any = {
val array1 = input1.asInstanceOf[ArrayData]
val array2 = input2.asInstanceOf[ArrayData]

if (elementTypeSupportEquals) {
elementType match {
case IntegerType =>
// avoid boxing of primitive int array elements
// calculate result array size
hsInt = new OpenHashSet[Int]
val elements = evalIntLongPrimitiveType(array1, array2, null, false)
hsInt = new OpenHashSet[Int]
val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
IntegerType.defaultSize, elements)) {
new GenericArrayData(new Array[Any](elements))
} else {
UnsafeArrayData.forPrimitiveArray(
Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
}
evalIntLongPrimitiveType(array1, array2, resultArray, false)
resultArray
case LongType =>
// avoid boxing of primitive long array elements
// calculate result array size
hsLong = new OpenHashSet[Long]
val elements = evalIntLongPrimitiveType(array1, array2, null, true)
hsLong = new OpenHashSet[Long]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once we obtain unique elements of two arrays in the hash set, can't we get final array elements from it directly instead of scanning two arrays again?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be. Originally, I took that approach.
After discussed with @ueshin, I decided to generate a result array from the original arrays instead of the hash. This is because we generate a result array in a unique deterministic order among the different paths in array_union.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, though I think there will be some performance issue.

val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
LongType.defaultSize, elements)) {
new GenericArrayData(new Array[Any](elements))
} else {
UnsafeArrayData.forPrimitiveArray(
Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
}
evalIntLongPrimitiveType(array1, array2, resultArray, true)
resultArray
case _ =>
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val hs = new OpenHashSet[Any]
var foundNullElement = false
Seq(array1, array2).foreach { array =>
var i = 0
while (i < array.numElements()) {
if (array.isNullAt(i)) {
if (!foundNullElement) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This two if can be combined?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is not easy since we want to do nothing if array.isNullAt(i) && foundNullElement is true.

arrayBuffer += null
foundNullElement = true
}
} else {
val elem = array.get(i, elementType)
if (!hs.contains(elem)) {
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
}
arrayBuffer += elem
hs.add(elem)
}
}
i += 1
}
}
new GenericArrayData(arrayBuffer)
}
} else {
ArrayUnion.unionOrdering(array1, array2, elementType, ordering)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val i = ctx.freshName("i")
val pos = ctx.freshName("pos")
val value = ctx.freshName("value")
val size = ctx.freshName("size")
val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) =
if (elementTypeSupportEquals) {
elementType match {
case ByteType | ShortType | IntegerType | LongType =>
val ptName = CodeGenerator.primitiveTypeName(elementType)
val unsafeArray = ctx.freshName("unsafeArray")
(if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp",
if (elementType == LongType) "Long" else "Int",
s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType),
if (elementType == LongType) "(long)" else "(int)",
s"""
|${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we don't automatically choose to use GenericArrayData as the same as interpreted path?

Copy link
Member Author

@kiszk kiszk Jul 11, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your comment is correct. It would be good to address this choice in another PR to update ctx.createUnsafeArray.
cc: @ueshin

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean a refactoring around the usage of createUnsafeArray through new collection functions in another PR? If so, I'm okay with doing it in another PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I mean a refactoring the usage of createUnsafeArray thru new collection functions.

|${ev.value} = $unsafeArray;
""".stripMargin)
case _ =>
val genericArrayData = classOf[GenericArrayData].getName
val et = ctx.addReferenceObj("elementType", elementType)
("", "Object",
s"get($i, $et)", s"update($pos, $value)", "Object", "",
s"${ev.value} = new $genericArrayData(new Object[$size]);")
}
} else {
("", "", "", "", "", "", "")
}

nullSafeCodeGen(ctx, ev, (array1, array2) => {
if (openHashElementType != "") {
// Here, we ensure elementTypeSupportEquals is true
val foundNullElement = ctx.freshName("foundNullElement")
val openHashSet = classOf[OpenHashSet[_]].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
val hs = ctx.freshName("hs")
val arrayData = classOf[ArrayData].getName
val arrays = ctx.freshName("arrays")
val array = ctx.freshName("array")
val arrayDataIdx = ctx.freshName("arrayDataIdx")
s"""
|$openHashSet $hs = new $openHashSet$postFix($classTag);
|boolean $foundNullElement = false;
|$arrayData[] $arrays = new $arrayData[]{$array1, $array2};
|for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
| $arrayData $array = $arrays[$arrayDataIdx];
| for (int $i = 0; $i < $array.numElements(); $i++) {
| if ($array.isNullAt($i)) {
| $foundNullElement = true;
| } else {
| $hs.add$postFix($array.$getter);
| }
| }
|}
|int $size = $hs.size() + ($foundNullElement ? 1 : 0);
|$arrayBuilder
|$hs = new $openHashSet$postFix($classTag);
|$foundNullElement = false;
|int $pos = 0;
|for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
| $arrayData $array = $arrays[$arrayDataIdx];
| for (int $i = 0; $i < $array.numElements(); $i++) {
| if ($array.isNullAt($i)) {
| if (!$foundNullElement) {
| ${ev.value}.setNullAt($pos++);
| $foundNullElement = true;
| }
| } else {
| $javaTypeName $value = $array.$getter;
| if (!$hs.contains($castOp $value)) {
| $hs.add$postFix($value);
| ${ev.value}.$setter;
| $pos++;
| }
| }
| }
|}
""".stripMargin
} else {
val arrayUnion = classOf[ArrayUnion].getName
val et = ctx.addReferenceObj("elementTypeUnion", elementType)
val order = ctx.addReferenceObj("orderingUnion", ordering)
val method = "unionOrdering"
s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);"
}
})
}

override def prettyName: String = "array_union"
}

object ArrayUnion {
def unionOrdering(
array1: ArrayData,
array2: ArrayData,
elementType: DataType,
ordering: Ordering[Any]): ArrayData = {
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
var alreadyIncludeNull = false
Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => {
var found = false
if (elem == null) {
if (alreadyIncludeNull) {
found = true
} else {
alreadyIncludeNull = true
}
} else {
// check elem is already stored in arrayBuffer or not?
var j = 0
while (!found && j < arrayBuffer.size) {
val va = arrayBuffer(j)
if (va != null && ordering.equiv(va, elem)) {
found = true
}
j = j + 1
}
}
if (!found) {
if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length)
}
arrayBuffer += elem
}
}))
new GenericArrayData(arrayBuffer)
}
}
Loading