Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3767,230 +3767,159 @@ object ArraySetLike {
""",
since = "2.4.0")
case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
with ComplexTypeMergingExpression {
var hsInt: OpenHashSet[Int] = _
var hsLong: OpenHashSet[Long] = _

def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
val elem = array.getInt(idx)
if (!hsInt.contains(elem)) {
if (resultArray != null) {
resultArray.setInt(pos, elem)
}
hsInt.add(elem)
true
} else {
false
}
}

def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
val elem = array.getLong(idx)
if (!hsLong.contains(elem)) {
if (resultArray != null) {
resultArray.setLong(pos, elem)
}
hsLong.add(elem)
true
} else {
false
}
}
with ComplexTypeMergingExpression {

def evalIntLongPrimitiveType(
array1: ArrayData,
array2: ArrayData,
resultArray: ArrayData,
isLongType: Boolean): Int = {
// store elements into resultArray
var nullElementSize = 0
var pos = 0
Seq(array1, array2).foreach { array =>
var i = 0
while (i < array.numElements()) {
val size = if (!isLongType) hsInt.size else hsLong.size
if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
ArraySetLike.throwUnionLengthOverflowException(size)
}
if (array.isNullAt(i)) {
if (nullElementSize == 0) {
if (resultArray != null) {
resultArray.setNullAt(pos)
@transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = {
if (elementTypeSupportEquals) {
(array1, array2) =>
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val hs = new OpenHashSet[Any]
var foundNullElement = false
Seq(array1, array2).foreach { array =>
var i = 0
while (i < array.numElements()) {
if (array.isNullAt(i)) {
if (!foundNullElement) {
arrayBuffer += null
foundNullElement = true
}
} else {
val elem = array.get(i, elementType)
if (!hs.contains(elem)) {
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
}
arrayBuffer += elem
hs.add(elem)
}
}
pos += 1
nullElementSize = 1
i += 1
}
} else {
val assigned = if (!isLongType) {
assignInt(array, i, resultArray, pos)
}
new GenericArrayData(arrayBuffer)
} else {
(array1, array2) =>
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
var alreadyIncludeNull = false
Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => {
var found = false
if (elem == null) {
if (alreadyIncludeNull) {
found = true
} else {
alreadyIncludeNull = true
}
} else {
assignLong(array, i, resultArray, pos)
// check elem is already stored in arrayBuffer or not?
var j = 0
while (!found && j < arrayBuffer.size) {
val va = arrayBuffer(j)
if (va != null && ordering.equiv(va, elem)) {
found = true
}
j = j + 1
}
}
if (assigned) {
pos += 1
if (!found) {
if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length)
}
arrayBuffer += elem
}
}
i += 1
}
}))
new GenericArrayData(arrayBuffer)
}
pos
}

override def nullSafeEval(input1: Any, input2: Any): Any = {
val array1 = input1.asInstanceOf[ArrayData]
val array2 = input2.asInstanceOf[ArrayData]

if (elementTypeSupportEquals) {
elementType match {
case IntegerType =>
// avoid boxing of primitive int array elements
// calculate result array size
hsInt = new OpenHashSet[Int]
val elements = evalIntLongPrimitiveType(array1, array2, null, false)
hsInt = new OpenHashSet[Int]
val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
IntegerType.defaultSize, elements)) {
new GenericArrayData(new Array[Any](elements))
} else {
UnsafeArrayData.forPrimitiveArray(
Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
}
evalIntLongPrimitiveType(array1, array2, resultArray, false)
resultArray
case LongType =>
// avoid boxing of primitive long array elements
// calculate result array size
hsLong = new OpenHashSet[Long]
val elements = evalIntLongPrimitiveType(array1, array2, null, true)
hsLong = new OpenHashSet[Long]
val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
LongType.defaultSize, elements)) {
new GenericArrayData(new Array[Any](elements))
} else {
UnsafeArrayData.forPrimitiveArray(
Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
}
evalIntLongPrimitiveType(array1, array2, resultArray, true)
resultArray
case _ =>
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val hs = new OpenHashSet[Any]
var foundNullElement = false
Seq(array1, array2).foreach { array =>
var i = 0
while (i < array.numElements()) {
if (array.isNullAt(i)) {
if (!foundNullElement) {
arrayBuffer += null
foundNullElement = true
}
} else {
val elem = array.get(i, elementType)
if (!hs.contains(elem)) {
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
}
arrayBuffer += elem
hs.add(elem)
}
}
i += 1
}
}
new GenericArrayData(arrayBuffer)
}
} else {
ArrayUnion.unionOrdering(array1, array2, elementType, ordering)
}
evalUnion(array1, array2)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val i = ctx.freshName("i")
val pos = ctx.freshName("pos")
val value = ctx.freshName("value")
val size = ctx.freshName("size")
val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) =
if (elementTypeSupportEquals) {
elementType match {
case ByteType | ShortType | IntegerType | LongType =>
val ptName = CodeGenerator.primitiveTypeName(elementType)
val unsafeArray = ctx.freshName("unsafeArray")
(if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp",
if (elementType == LongType) "Long" else "Int",
s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType),
if (elementType == LongType) "(long)" else "(int)",
s"""
|${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
|${ev.value} = $unsafeArray;
""".stripMargin)
case _ =>
val genericArrayData = classOf[GenericArrayData].getName
val et = ctx.addReferenceObj("elementType", elementType)
("", "Object",
s"get($i, $et)", s"update($pos, $value)", "Object", "",
s"${ev.value} = new $genericArrayData(new Object[$size]);")
}
} else {
("", "", "", "", "", "", "")
}
if (canUseSpecializedHashSet) {
val jt = CodeGenerator.javaType(elementType)
val ptName = CodeGenerator.primitiveTypeName(jt)

nullSafeCodeGen(ctx, ev, (array1, array2) => {
if (openHashElementType != "") {
// Here, we ensure elementTypeSupportEquals is true
nullSafeCodeGen(ctx, ev, (array1, array2) => {
val foundNullElement = ctx.freshName("foundNullElement")
val openHashSet = classOf[OpenHashSet[_]].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
val hs = ctx.freshName("hs")
val arrayData = classOf[ArrayData].getName
val arrays = ctx.freshName("arrays")
val nullElementIndex = ctx.freshName("nullElementIndex")
val builder = ctx.freshName("builder")
val array = ctx.freshName("array")
val arrays = ctx.freshName("arrays")
val arrayDataIdx = ctx.freshName("arrayDataIdx")
val openHashSet = classOf[OpenHashSet[_]].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
val hashSet = ctx.freshName("hashSet")
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed any more?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that we still need this to create an intermediate result array. The array is allocated at L3907 as other functions do.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, sorry, I misread.


def withArrayNullAssignment(body: String) =
if (dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|if ($array.isNullAt($i)) {
| if (!$foundNullElement) {
| $nullElementIndex = $size;
| $foundNullElement = true;
| $size++;
| $builder.$$plus$$eq($nullValueHolder);
| }
|} else {
| $body
|}
""".stripMargin
} else {
body
}

val processArray = withArrayNullAssignment(
s"""
|$jt $value = ${genGetValue(array, i)};
|if (!$hashSet.contains($hsValueCast$value)) {
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| break;
| }
| $hashSet.add$hsPostFix($hsValueCast$value);
| $builder.$$plus$$eq($value);
|}
""".stripMargin)

// Only need to track null element index when result array's element is nullable.
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|boolean $foundNullElement = false;
|int $nullElementIndex = -1;
""".stripMargin
} else {
""
}

s"""
|$openHashSet $hs = new $openHashSet$postFix($classTag);
|boolean $foundNullElement = false;
|$arrayData[] $arrays = new $arrayData[]{$array1, $array2};
|for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
| $arrayData $array = $arrays[$arrayDataIdx];
| for (int $i = 0; $i < $array.numElements(); $i++) {
| if ($array.isNullAt($i)) {
| $foundNullElement = true;
| } else {
| $hs.add$postFix($array.$getter);
| }
| }
|}
|int $size = $hs.size() + ($foundNullElement ? 1 : 0);
|$arrayBuilder
|$hs = new $openHashSet$postFix($classTag);
|$foundNullElement = false;
|int $pos = 0;
|$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag);
|$declareNullTrackVariables
|int $size = 0;
|$arrayBuilderClass $builder = new $arrayBuilderClass();
|ArrayData[] $arrays = new ArrayData[]{$array1, $array2};
|for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
| $arrayData $array = $arrays[$arrayDataIdx];
| ArrayData $array = $arrays[$arrayDataIdx];
| for (int $i = 0; $i < $array.numElements(); $i++) {
| if ($array.isNullAt($i)) {
| if (!$foundNullElement) {
| ${ev.value}.setNullAt($pos++);
| $foundNullElement = true;
| }
| } else {
| $javaTypeName $value = $array.$getter;
| if (!$hs.contains($castOp $value)) {
| $hs.add$postFix($value);
| ${ev.value}.$setter;
| $pos++;
| }
| }
| $processArray
| }
|}
|${buildResultArray(builder, ev.value, size, nullElementIndex)}
""".stripMargin
} else {
val arrayUnion = classOf[ArrayUnion].getName
val et = ctx.addReferenceObj("elementTypeUnion", elementType)
val order = ctx.addReferenceObj("orderingUnion", ordering)
val method = "unionOrdering"
s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);"
}
})
})
} else {
nullSafeCodeGen(ctx, ev, (array1, array2) => {
val expr = ctx.addReferenceObj("arrayUnionExpr", this)
s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);"
})
}
}

override def prettyName: String = "array_union"
Expand Down Expand Up @@ -4154,7 +4083,6 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val arrayData = classOf[ArrayData].getName
val i = ctx.freshName("i")
val value = ctx.freshName("value")
val size = ctx.freshName("size")
Expand Down Expand Up @@ -4268,7 +4196,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
} else {
nullSafeCodeGen(ctx, ev, (array1, array2) => {
val expr = ctx.addReferenceObj("arrayIntersectExpr", this)
s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);"
s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);"
})
}
}
Expand Down Expand Up @@ -4387,7 +4315,6 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val arrayData = classOf[ArrayData].getName
val i = ctx.freshName("i")
val value = ctx.freshName("value")
val size = ctx.freshName("size")
Expand Down Expand Up @@ -4490,7 +4417,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
} else {
nullSafeCodeGen(ctx, ev, (array1, array2) => {
val expr = ctx.addReferenceObj("arrayExceptExpr", this)
s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);"
s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);"
})
}
}
Expand Down
Loading