From dc9d6f0f2bec0d44a0ac18d6736d11b35f7c597b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 13 Apr 2018 07:08:06 +0100 Subject: [PATCH 01/34] initial commit --- python/pyspark/sql/functions.py | 18 +++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 103 +++++++++++++++--- .../CollectionExpressionsSuite.scala | 38 +++++++ .../org/apache/spark/sql/functions.scala | 11 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 41 +++++++ 6 files changed, 199 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9652d3e79b875..027e1db6e7c73 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2013,6 +2013,24 @@ def array_distinct(col): return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) +@ignore_unicode_prefix +@since(2.4) +def array_union(col1, col2): + """ + Collection function: Returns an array of the elements in the union of col1 and col2, + without duplicates + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_union(df.c1, df.c2)).collect() + [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f']))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a574d8a84d4fb..981fed6b7268d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -414,6 +414,7 @@ object FunctionRegistry { expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), + expression[ArrayUnion]("array_union"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8b278f067749e..b8368bbbdf538 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2021,7 +2021,7 @@ case class Concat(children: Seq[Expression]) extends Expression { ByteArray.concat(inputs: _*) case StringType => val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs : _*) + UTF8String.concat(inputs: _*) case ArrayType(elementType, _) => val inputs = children.toStream.map(_.eval(input)) if (inputs.contains(null)) { @@ -2035,7 +2035,7 @@ case class Concat(children: Seq[Expression]) extends Expression { } val finalData = new Array[AnyRef](numberOfElements.toInt) var position = 0 - for(ad <- arrayData) { + for (ad <- arrayData) { val arr = ad.toObjectArray(elementType) Array.copy(arr, 0, finalData, position, arr.length) position += arr.length @@ -2082,23 +2082,24 @@ case class Concat(children: 
Seq[Expression]) extends Expression { """) } - private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { + private def genCodeForNumberOfElements(ctx: CodegenContext): (String, String) = { val numElements = ctx.freshName("numElements") - val code = s""" - |long $numElements = 0L; - |for (int z = 0; z < ${children.length}; z++) { - | $numElements += args[z].numElements(); - |} - |if ($numElements > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + - | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); - |} + val code = + s""" + |long $numElements = 0L; + |for (int z = 0; z < ${children.length}; z++) { + | $numElements += args[z].numElements(); + |} + |if ($numElements > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + + | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} """.stripMargin (code, numElements) } - private def nullArgumentProtection() : String = { + private def nullArgumentProtection(): String = { if (nullable) { s""" |for (int z = 0; z < ${children.length}; z++) { @@ -2858,7 +2859,83 @@ case class ArrayRepeat(left: Expression, right: Expression) |$arrayDataName = new $genericArrayClass($arrayName); """.stripMargin } +} + +/** + * Returns an array of the elements in the union of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, value) - Returns an array of the elements in the union of x and y, + without duplicates. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes with CodegenFallback { + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { + val r = super.checkInputDataTypes() + if ((r == TypeCheckResult.TypeCheckSuccess) && + (left.dataType.asInstanceOf[ArrayType].elementType != + right.dataType.asInstanceOf[ArrayType].elementType)) { + TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") + } else { + r + } + } + + override def dataType: DataType = left.dataType + + override def nullSafeEval(linput: Any, rinput: Any): Any = { + val elementType = dataType.asInstanceOf[ArrayType].elementType + val cnl = left.dataType.asInstanceOf[ArrayType].containsNull + val cnr = right.dataType.asInstanceOf[ArrayType].containsNull + val larray = linput.asInstanceOf[ArrayData] + val rarray = rinput.asInstanceOf[ArrayData] + + if (!cnl && !cnr && elementType == IntegerType) { + // avoid boxing primitive int array elements + val hs = new OpenHashSet[Int] + var i = 0 + while (i < larray.numElements()) { + hs.add(larray.getInt(i)) + i += 1 + } + i = 0 + while (i < rarray.numElements()) { + hs.add(rarray.getInt(i)) + i += 1 + } + UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) + } else if (!cnl && !cnr && elementType == LongType) { + // avoid boxing of primitive long array elements + val hs = new OpenHashSet[Long] + var i = 0 + while (i < larray.numElements()) { + hs.add(larray.getLong(i)) + i += 1 + } + i = 0 + while (i < rarray.numElements()) { + hs.add(rarray.getLong(i)) + i += 1 + } + UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) + } else { + new GenericArrayData( + (larray.toArray[AnyRef](elementType) union rarray.toArray[AnyRef](elementType)) + 
.distinct.asInstanceOf[Array[Any]]) + } + } + override def prettyName: String = "array_union" } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d7744eb4c7dc7..3a17b2c5c8144 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1166,4 +1166,42 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1))) } + + test("Array Union") { + val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) + val a02 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a03 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType)) + val a04 = Literal.create(Seq(-5, 4, -3, 2, -1), ArrayType(IntegerType)) + val a05 = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) + + val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) + val a12 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType)) + val a13 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType)) + val a14 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), ArrayType(LongType)) + val a15 = Literal.create(Seq.empty[Long], ArrayType(LongType)) + + val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType)) + val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType)) + val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType)) + + val a30 = Literal.create(Seq(null, null), ArrayType(NullType)) + + checkEvaluation(ArrayUnion(a00, a01), UnsafeArrayData.fromPrimitiveArray(Array(4, 1, 3, 2))) + checkEvaluation(ArrayUnion(a01, a02), Seq(4, 2, 1, 3)) + checkEvaluation(ArrayUnion(a03, a04), Seq(1, 2, null, 4, 5, -5, -3, -1)) + checkEvaluation(ArrayUnion(a03, a05), Seq(1, 2, null, 4, 5)) + + checkEvaluation( + ArrayUnion(a10, a11), UnsafeArrayData.fromPrimitiveArray(Array(4L, 1L, 3L, 2L))) + checkEvaluation(ArrayUnion(a11, a12), Seq(4L, 2L, 1L, 3L)) + checkEvaluation(ArrayUnion(a13, a14), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkEvaluation(ArrayUnion(a13, a15), Seq(1L, 2L, null, 4L, 5L)) + + checkEvaluation(ArrayUnion(a20, a21), Seq("b", "a", "c", "d", "f")) + checkEvaluation(ArrayUnion(a20, a22), Seq("b", "a", "c", null, "g")) + + checkEvaluation(ArrayUnion(a30, a30), Seq(null)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0b4f526799578..9470364736b13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3182,6 +3182,7 @@ object functions { /** * Remove all elements that equal to element from the given array. + * * @group collection_funcs * @since 2.4.0 */ @@ -3196,6 +3197,16 @@ object functions { */ def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + /** + * Returns an array of the elements in the union of the given two arrays, without duplicates. 
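For reference, a minimal usage sketch of the new DataFrame API — this assumes Spark 2.4+ and a local `SparkSession`; the element ordering shown mirrors the OpenHashSet-based answers asserted in `DataFrameFunctionsSuite` below and is not guaranteed in general:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.array_union

val spark = SparkSession.builder().master("local[*]").appName("array_union-demo").getOrCreate()
import spark.implicits._

val df = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b")
// The Column API and the SQL form resolve to the same ArrayUnion expression.
df.select(array_union($"a", $"b")).collect()   // Array(Row(Seq(4, 1, 3, 2)))
df.selectExpr("array_union(a, b)").collect()   // same result, via FunctionRegistry
```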
+ * + * @group collection_funcs + * @since 2.4.0 + */ + def array_union(col1: Column, col2: Column): Column = withExpr { + ArrayUnion(col1.expr, col2.expr) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4c28e2f1cd909..dbf4fdc443533 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1120,6 +1120,47 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { "argument 1 requires (array or map) type, however, '`_1`' is of string type")) } + test("array union functions") { + val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(4, 1, 3, 2)) + checkAnswer(df1.select(array_union($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_union(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array(-5, 4, -3, 2, -1))).toDF("a", "b") + val ans2 = Row(Seq(1, 2, null, 4, 5, -5, -3, -1)) + checkAnswer(df2.select(array_union($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_union(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 3L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(4L, 1L, 3L, 2L)) + checkAnswer(df3.select(array_union($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_union(a, b)"), ans3) + + val df4 = Seq((Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array(-5L, 4L, -3L, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkAnswer(df4.select(array_union($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_union(a, b)"), ans4) + + val df5 = Seq((Array("b", "a", "c"), Array("b", null, "a", "g"))).toDF("a", "b") + val ans5 = Row(Seq("b", "a", "c", null, "g")) + checkAnswer(df5.select(array_union($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_union(a, b)"), ans5) + + val df6 = Seq((null, null)).toDF("a", "b") + val ans6 = Row(null) + checkAnswer(df6.select(array_union($"a", $"b")), ans6) + checkAnswer(df6.selectExpr("array_union(a, b)"), ans6) + + val df0 = Seq((Array(1), Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df0.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df0.selectExpr("array_contains(a, b)") + } + } + test("concat function - arrays") { val nseqi : Seq[Int] = null val nseqs : Seq[String] = null From 301984061cf406537122f72c44c1e2fab66aa0de Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 13 Apr 2018 07:20:34 +0100 Subject: [PATCH 02/34] update description --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b8368bbbdf538..23a618687cf05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2866,7 +2866,7 @@ case class ArrayRepeat(left: Expression, right: Expression) */ @ExpressionDescription( usage = """ - _FUNC_(array, value) - Returns an array of the elements in the union of x and y, + _FUNC_(array1, array2) - Returns an array of the elements in 
the union of array1 and array2, without duplicates. """, examples = """ From 8cee6cff412bd62a60cb81ceb1d86269d5094bf5 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 13 Apr 2018 11:51:59 +0100 Subject: [PATCH 03/34] fix test failure --- python/pyspark/sql/functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 027e1db6e7c73..6387a9372388a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2023,6 +2023,7 @@ def array_union(col1, col2): :param col1: name of column containing array :param col2: name of column containing array + >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_union(df.c1, df.c2)).collect() [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f']))] From 2041ec45efdcb2b3ae9dfc7c5b7c6dc26c0091ea Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 17 Apr 2018 18:16:18 +0100 Subject: [PATCH 04/34] address review comments --- python/pyspark/sql/functions.py | 7 +- .../expressions/collectionOperations.scala | 133 +++++++++++++----- .../CollectionExpressionsSuite.scala | 3 + .../org/apache/spark/sql/functions.scala | 1 + 4 files changed, 108 insertions(+), 36 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6387a9372388a..20931bf0a90d3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1942,6 +1942,7 @@ def concat(*cols): return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) +@ignore_unicode_prefix @since(2.4) def array_position(col, value): """ @@ -2017,8 +2018,8 @@ def array_distinct(col): @since(2.4) def array_union(col1, col2): """ - Collection function: Returns an array of the elements in the union of col1 and col2, - without duplicates + Collection function: returns an array of the elements in the union of col1 and col2, + without duplicates. The order of elements in the result is not determined. 
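The order caveat comes from the implementation: both inputs are folded into one `OpenHashSet`, and iteration follows hash-probe order rather than insertion order. A sketch of that behavior — note `org.apache.spark.util.collection.OpenHashSet` is `private[spark]`, so this only compiles inside Spark's own source tree:

```scala
import org.apache.spark.util.collection.OpenHashSet

val hs = new OpenHashSet[Int]
Seq(1, 2, 3).foreach(hs.add)   // elements of the first array
Seq(4, 2).foreach(hs.add)      // elements of the second array; the duplicate 2 is dropped
hs.iterator.toArray            // e.g. Array(4, 1, 3, 2) — probe order, not insertion order
```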
:param col1: name of column containing array :param col2: name of column containing array @@ -2026,7 +2027,7 @@ def array_union(col1, col2): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_union(df.c1, df.c2)).collect() - [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f']))] + [Row(array_union(c1, c2)=[u'b', u'c', u'd', u'a', u'f']))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 23a618687cf05..4adffd3e37a7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2876,7 +2876,7 @@ case class ArrayRepeat(left: Expression, right: Expression) """, since = "2.4.0") case class ArrayUnion(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes with CodegenFallback { + extends BinaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) @@ -2893,46 +2893,106 @@ case class ArrayUnion(left: Expression, right: Expression) override def dataType: DataType = left.dataType + private def elementType = dataType.asInstanceOf[ArrayType].elementType + private def cnLeft = left.dataType.asInstanceOf[ArrayType].containsNull + private def cnRight = right.dataType.asInstanceOf[ArrayType].containsNull + override def nullSafeEval(linput: Any, rinput: Any): Any = { - val elementType = dataType.asInstanceOf[ArrayType].elementType - val cnl = left.dataType.asInstanceOf[ArrayType].containsNull - val cnr = right.dataType.asInstanceOf[ArrayType].containsNull val larray = linput.asInstanceOf[ArrayData] val rarray = rinput.asInstanceOf[ArrayData] - if (!cnl && !cnr && elementType == IntegerType) { - // avoid boxing primitive int array elements - val hs = new OpenHashSet[Int] - var i = 0 - while (i < larray.numElements()) { - hs.add(larray.getInt(i)) - i += 1 - } - i = 0 - while (i < rarray.numElements()) { - hs.add(rarray.getInt(i)) - i += 1 - } - UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) - } else if (!cnl && !cnr && elementType == LongType) { - // avoid boxing of primitive long array elements - val hs = new OpenHashSet[Long] - var i = 0 - while (i < larray.numElements()) { - hs.add(larray.getLong(i)) - i += 1 + if (!cnLeft && !cnRight) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + val hs = new OpenHashSet[Int] + var i = 0 + while (i < larray.numElements()) { + hs.add(larray.getInt(i)) + i += 1 + } + i = 0 + while (i < rarray.numElements()) { + hs.add(rarray.getInt(i)) + i += 1 + } + UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) + case LongType => + // avoid boxing of primitive long array elements + val hs = new OpenHashSet[Long] + var i = 0 + while (i < larray.numElements()) { + hs.add(larray.getLong(i)) + i += 1 + } + i = 0 + while (i < rarray.numElements()) { + hs.add(rarray.getLong(i)) + i += 1 + } + UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) + case _ => + val hs = new OpenHashSet[Any] + var i = 0 + while (i < larray.numElements()) { + hs.add(larray.get(i, elementType)) + i += 1 + } + i = 0 + while 
(i < rarray.numElements()) { + hs.add(rarray.get(i, elementType)) + i += 1 + } + new GenericArrayData(hs.iterator.toArray) } - i = 0 - while (i < rarray.numElements()) { - hs.add(rarray.getLong(i)) - i += 1 + } else { + CollectionOperations.arrayUnion(larray, rarray, elementType) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val hs = ctx.freshName("hs") + val i = ctx.freshName("i") + val collectionOperations = "org.apache.spark.sql.catalyst.expressions.CollectionOperations" + val genericArrayData = classOf[GenericArrayData].getName + val unsafeArrayData = classOf[UnsafeArrayData].getName + val openHashSet = classOf[OpenHashSet[_]].getName + val ot = "org.apache.spark.sql.types.ObjectType$.MODULE$.apply(Object.class)" + val (postFix, classTag, getter, arrayBuilder, castType) = if (!cnLeft && !cnRight) { + val ptName = CodeGenerator.primitiveTypeName(elementType) + elementType match { + case ByteType | ShortType | IntegerType => + (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", + s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType)) + case LongType => + (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", + s"$unsafeArrayData.fromPrimitiveArray", "long") + case _ => + ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $ot)", + s"new $genericArrayData", "Object") } - UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) } else { - new GenericArrayData( - (larray.toArray[AnyRef](elementType) union rarray.toArray[AnyRef](elementType)) - .distinct.asInstanceOf[Array[Any]]) + ("", "", "", "", "") } + + nullSafeCodeGen(ctx, ev, (larray, rarray) => { + if (classTag != "") { + s""" + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |for (int $i = 0; $i < $larray.numElements(); $i++) { + | $hs.add$postFix($larray.$getter); + |} + |for (int $i = 0; $i < $rarray.numElements(); $i++) { + | $hs.add$postFix($rarray.$getter); + |} + |${ev.value} = $arrayBuilder( + | ($castType[]) $hs.iterator().toArray($classTag)); + """.stripMargin + } else { + val dt = "org.apache.spark.sql.types.ObjectType$.MODULE$.apply(Object.class)" + s"${ev.value} = $collectionOperations$$.MODULE$$.arrayUnion($larray, $rarray, $ot);" + } + }) } override def prettyName: String = "array_union" @@ -3338,3 +3398,10 @@ case class ArrayDistinct(child: Expression) override def prettyName: String = "array_distinct" } + +object CollectionOperations { + def arrayUnion(larray: ArrayData, rarray: ArrayData, et: DataType): ArrayData = { + new GenericArrayData(larray.toArray[AnyRef](et).union(rarray.toArray[AnyRef](et)) + .distinct.asInstanceOf[Array[Any]]) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 3a17b2c5c8144..332dff183ac99 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1185,6 +1185,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType)) val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType)) val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType)) + val a23 = Literal.create(Seq("b", "a", "c"), 
ArrayType(StringType, false)) + val a24 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, false)) val a30 = Literal.create(Seq(null, null), ArrayType(NullType)) @@ -1201,6 +1203,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayUnion(a20, a21), Seq("b", "a", "c", "d", "f")) checkEvaluation(ArrayUnion(a20, a22), Seq("b", "a", "c", null, "g")) + checkEvaluation(ArrayUnion(a23, a24), Seq("b", "c", "d", "a", "f")) checkEvaluation(ArrayUnion(a30, a30), Seq(null)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9470364736b13..33223f0fe8cc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3199,6 +3199,7 @@ object functions { /** * Returns an array of the elements in the union of the given two arrays, without duplicates. + * The order of elements in the result is not determined * * @group collection_funcs * @since 2.4.0 From 8c2280be254769b51342c6afa41801e88c6b0bee Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 17 Apr 2018 20:31:03 +0100 Subject: [PATCH 05/34] introduce ArraySetUtils to reuse code among array_union/array_intersect/array_except --- .../expressions/collectionOperations.scala | 293 +++++++++--------- .../spark/sql/DataFrameFunctionsSuite.scala | 2 +- 2 files changed, 154 insertions(+), 141 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4adffd3e37a7b..29a1c4a52b8da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2861,143 +2861,6 @@ case class ArrayRepeat(left: Expression, right: Expression) } } -/** - * Returns an array of the elements in the union of x and y, without duplicates - */ -@ExpressionDescription( - usage = """ - _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, - without duplicates. 
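The dedicated Int and Long branches exist purely to avoid boxing, and both finish with `UnsafeArrayData.fromPrimitiveArray`. A quick sketch of what that helper returns (assuming the catalyst module is on the classpath):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData

// fromPrimitiveArray copies a primitive int[]/long[] directly into Spark's
// unsafe array format, so neither the hash set nor the result boxes an element.
val result = UnsafeArrayData.fromPrimitiveArray(Array(4, 1, 3, 2))
assert(result.numElements() == 4)
assert(result.getInt(0) == 4)
```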
- """, - examples = """ - Examples: - > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); - array(1, 2, 3, 5) - """, - since = "2.4.0") -case class ArrayUnion(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) - - override def checkInputDataTypes(): TypeCheckResult = { - val r = super.checkInputDataTypes() - if ((r == TypeCheckResult.TypeCheckSuccess) && - (left.dataType.asInstanceOf[ArrayType].elementType != - right.dataType.asInstanceOf[ArrayType].elementType)) { - TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") - } else { - r - } - } - - override def dataType: DataType = left.dataType - - private def elementType = dataType.asInstanceOf[ArrayType].elementType - private def cnLeft = left.dataType.asInstanceOf[ArrayType].containsNull - private def cnRight = right.dataType.asInstanceOf[ArrayType].containsNull - - override def nullSafeEval(linput: Any, rinput: Any): Any = { - val larray = linput.asInstanceOf[ArrayData] - val rarray = rinput.asInstanceOf[ArrayData] - - if (!cnLeft && !cnRight) { - elementType match { - case IntegerType => - // avoid boxing of primitive int array elements - val hs = new OpenHashSet[Int] - var i = 0 - while (i < larray.numElements()) { - hs.add(larray.getInt(i)) - i += 1 - } - i = 0 - while (i < rarray.numElements()) { - hs.add(rarray.getInt(i)) - i += 1 - } - UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) - case LongType => - // avoid boxing of primitive long array elements - val hs = new OpenHashSet[Long] - var i = 0 - while (i < larray.numElements()) { - hs.add(larray.getLong(i)) - i += 1 - } - i = 0 - while (i < rarray.numElements()) { - hs.add(rarray.getLong(i)) - i += 1 - } - UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) - case _ => - val hs = new OpenHashSet[Any] - var i = 0 - while (i < larray.numElements()) { - hs.add(larray.get(i, elementType)) - i += 1 - } - i = 0 - while (i < rarray.numElements()) { - hs.add(rarray.get(i, elementType)) - i += 1 - } - new GenericArrayData(hs.iterator.toArray) - } - } else { - CollectionOperations.arrayUnion(larray, rarray, elementType) - } - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val hs = ctx.freshName("hs") - val i = ctx.freshName("i") - val collectionOperations = "org.apache.spark.sql.catalyst.expressions.CollectionOperations" - val genericArrayData = classOf[GenericArrayData].getName - val unsafeArrayData = classOf[UnsafeArrayData].getName - val openHashSet = classOf[OpenHashSet[_]].getName - val ot = "org.apache.spark.sql.types.ObjectType$.MODULE$.apply(Object.class)" - val (postFix, classTag, getter, arrayBuilder, castType) = if (!cnLeft && !cnRight) { - val ptName = CodeGenerator.primitiveTypeName(elementType) - elementType match { - case ByteType | ShortType | IntegerType => - (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", - s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType)) - case LongType => - (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", - s"$unsafeArrayData.fromPrimitiveArray", "long") - case _ => - ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $ot)", - s"new $genericArrayData", "Object") - } - } else { - ("", "", "", "", "") - } - - nullSafeCodeGen(ctx, ev, (larray, rarray) => { - if (classTag != "") { - s""" - |$openHashSet $hs = new $openHashSet$postFix($classTag); - |for (int $i = 0; $i < 
$larray.numElements(); $i++) { - | $hs.add$postFix($larray.$getter); - |} - |for (int $i = 0; $i < $rarray.numElements(); $i++) { - | $hs.add$postFix($rarray.$getter); - |} - |${ev.value} = $arrayBuilder( - | ($castType[]) $hs.iterator().toArray($classTag)); - """.stripMargin - } else { - val dt = "org.apache.spark.sql.types.ObjectType$.MODULE$.apply(Object.class)" - s"${ev.value} = $collectionOperations$$.MODULE$$.arrayUnion($larray, $rarray, $ot);" - } - }) - } - - override def prettyName: String = "array_union" -} - /** * Remove all elements that equal to element from the given array */ @@ -3399,9 +3262,159 @@ case class ArrayDistinct(child: Expression) override def prettyName: String = "array_distinct" } -object CollectionOperations { - def arrayUnion(larray: ArrayData, rarray: ArrayData, et: DataType): ArrayData = { - new GenericArrayData(larray.toArray[AnyRef](et).union(rarray.toArray[AnyRef](et)) +abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { + val kindUnion = 1 + + def typeId: Int + + def array1: Expression + + def array2: Expression + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { + val r = super.checkInputDataTypes() + if ((r == TypeCheckResult.TypeCheckSuccess) && + (array1.dataType.asInstanceOf[ArrayType].elementType != + array2.dataType.asInstanceOf[ArrayType].elementType)) { + TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") + } else { + r + } + } + + override def dataType: DataType = array1.dataType + + private def elementType = dataType.asInstanceOf[ArrayType].elementType + + private def cn1 = array1.dataType.asInstanceOf[ArrayType].containsNull + + private def cn2 = array2.dataType.asInstanceOf[ArrayType].containsNull + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val ary1 = input1.asInstanceOf[ArrayData] + val ary2 = input2.asInstanceOf[ArrayData] + + if (!cn1 && !cn2) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + val hs = new OpenHashSet[Int] + var i = 0 + while (i < ary1.numElements()) { + hs.add(ary1.getInt(i)) + i += 1 + } + i = 0 + while (i < ary2.numElements()) { + hs.add(ary2.getInt(i)) + i += 1 + } + UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) + case LongType => + // avoid boxing of primitive long array elements + val hs = new OpenHashSet[Long] + var i = 0 + while (i < ary1.numElements()) { + hs.add(ary1.getLong(i)) + i += 1 + } + i = 0 + while (i < ary2.numElements()) { + hs.add(ary2.getLong(i)) + i += 1 + } + UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) + case _ => + val hs = new OpenHashSet[Any] + var i = 0 + while (i < ary1.numElements()) { + hs.add(ary1.get(i, elementType)) + i += 1 + } + i = 0 + while (i < ary2.numElements()) { + hs.add(ary2.get(i, elementType)) + i += 1 + } + new GenericArrayData(hs.iterator.toArray) + } + } else { + ArraySetUtils.arrayUnion(ary1, ary2, elementType) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val hs = ctx.freshName("hs") + val i = ctx.freshName("i") + val ArraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetUtils" + val genericArrayData = classOf[GenericArrayData].getName + val unsafeArrayData = classOf[UnsafeArrayData].getName + val openHashSet = classOf[OpenHashSet[_]].getName + val ot = "org.apache.spark.sql.types.ObjectType$.MODULE$.apply(Object.class)" + val (postFix, classTag, getter, arrayBuilder, castType) = if 
(!cn1 && !cn2) { + val ptName = CodeGenerator.primitiveTypeName(elementType) + elementType match { + case ByteType | ShortType | IntegerType => + (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", + s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType)) + case LongType => + (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", + s"$unsafeArrayData.fromPrimitiveArray", "long") + case _ => + ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $ot)", + s"new $genericArrayData", "Object") + } + } else { + ("", "", "", "", "") + } + + nullSafeCodeGen(ctx, ev, (ary1, ary2) => { + if (classTag != "") { + s""" + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |for (int $i = 0; $i < $ary1.numElements(); $i++) { + | $hs.add$postFix($ary1.$getter); + |} + |for (int $i = 0; $i < $ary2.numElements(); $i++) { + | $hs.add$postFix($ary2.$getter); + |} + |${ev.value} = $arrayBuilder(($castType[]) $hs.iterator().toArray($classTag)); + """.stripMargin + } else { + val dt = "org.apache.spark.sql.types.ObjectType$.MODULE$.apply(Object.class)" + s"${ev.value} = $ArraySetUtils$$.MODULE$$.arrayUnion($ary1, $ary2, $ot);" + } + }) + } +} + +object ArraySetUtils { + def arrayUnion(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { + new GenericArrayData(array1.toArray[AnyRef](et).union(array2.toArray[AnyRef](et)) .distinct.asInstanceOf[Array[Any]]) } } + +/** + * Returns an array of the elements in the union of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, + without duplicates. The order of elements in the result is not determined. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetUtils { + override def typeId: Int = kindUnion + override def array1: Expression = left + override def array2: Expression = right + + override def prettyName: String = "array_union" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index dbf4fdc443533..1a1b91024be88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1120,7 +1120,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { "argument 1 requires (array or map) type, however, '`_1`' is of string type")) } - test("array union functions") { + test("array_union functions") { val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b") val ans1 = Row(Seq(4, 1, 3, 2)) checkAnswer(df1.select(array_union($"a", $"b")), ans1) From b3a3132c1515fabce10a3782f83d86371335d5f3 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 18 Apr 2018 06:58:14 +0100 Subject: [PATCH 06/34] fix python test failure --- .../sql/catalyst/expressions/collectionOperations.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 29a1c4a52b8da..3cdfcdbc1d3cf 100644 --- 
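The fix below swaps the codegen's `ObjectType` reference for a concrete `DataTypes` constant; `ArrayData.get(ordinal, dataType)` dispatches on that element type, so the generated Java must name the real type. An interpreted sketch of the same call:

```scala
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types.IntegerType

val data = new GenericArrayData(Array[Any](1, 2, 3))
// get() needs the element's DataType to pick the right accessor, which is why
// the generated code references org.apache.spark.sql.types.DataTypes.IntegerType.
val elem = data.get(1, IntegerType)   // java.lang.Integer 2, returned as AnyRef
```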
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3352,7 +3352,7 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { val genericArrayData = classOf[GenericArrayData].getName val unsafeArrayData = classOf[UnsafeArrayData].getName val openHashSet = classOf[OpenHashSet[_]].getName - val ot = "org.apache.spark.sql.types.ObjectType$.MODULE$.apply(Object.class)" + val et = s"org.apache.spark.sql.types.DataTypes.$elementType" val (postFix, classTag, getter, arrayBuilder, castType) = if (!cn1 && !cn2) { val ptName = CodeGenerator.primitiveTypeName(elementType) elementType match { @@ -3363,7 +3363,7 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", s"$unsafeArrayData.fromPrimitiveArray", "long") case _ => - ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $ot)", + ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $et)", s"new $genericArrayData", "Object") } } else { @@ -3383,8 +3383,7 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { |${ev.value} = $arrayBuilder(($castType[]) $hs.iterator().toArray($classTag)); """.stripMargin } else { - val dt = "org.apache.spark.sql.types.ObjectType$.MODULE$.apply(Object.class)" - s"${ev.value} = $ArraySetUtils$$.MODULE$$.arrayUnion($ary1, $ary2, $ot);" + s"${ev.value} = $ArraySetUtils$$.MODULE$$.arrayUnion($ary1, $ary2, $et);" } }) } From a2c7dd1c63e5f3d8dabd324df79791bdaf3e9fb8 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 18 Apr 2018 15:16:44 +0100 Subject: [PATCH 07/34] fix python test failure --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 20931bf0a90d3..025c7a9f95732 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2027,7 +2027,7 @@ def array_union(col1, col2): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_union(df.c1, df.c2)).collect() - [Row(array_union(c1, c2)=[u'b', u'c', u'd', u'a', u'f']))] + [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f']))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2))) From 53136802b9b7094b7f4cec4a4ec83928c5c6b3ad Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 18 Apr 2018 19:18:48 +0100 Subject: [PATCH 08/34] simplification --- .../expressions/collectionOperations.scala | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3cdfcdbc1d3cf..8c581b9b6ca65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3267,30 +3267,24 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { def typeId: Int - def array1: Expression - - def array2: Expression - override def inputTypes: Seq[AbstractDataType] = 
Seq(ArrayType, ArrayType) override def checkInputDataTypes(): TypeCheckResult = { val r = super.checkInputDataTypes() if ((r == TypeCheckResult.TypeCheckSuccess) && - (array1.dataType.asInstanceOf[ArrayType].elementType != - array2.dataType.asInstanceOf[ArrayType].elementType)) { + (left.dataType.asInstanceOf[ArrayType].elementType != + right.dataType.asInstanceOf[ArrayType].elementType)) { TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") } else { r } } - override def dataType: DataType = array1.dataType + override def dataType: DataType = left.dataType private def elementType = dataType.asInstanceOf[ArrayType].elementType - - private def cn1 = array1.dataType.asInstanceOf[ArrayType].containsNull - - private def cn2 = array2.dataType.asInstanceOf[ArrayType].containsNull + private def cn1 = left.dataType.asInstanceOf[ArrayType].containsNull + private def cn2 = right.dataType.asInstanceOf[ArrayType].containsNull override def nullSafeEval(input1: Any, input2: Any): Any = { val ary1 = input1.asInstanceOf[ArrayData] @@ -3412,8 +3406,6 @@ object ArraySetUtils { since = "2.4.0") case class ArrayUnion(left: Expression, right: Expression) extends ArraySetUtils { override def typeId: Int = kindUnion - override def array1: Expression = left - override def array2: Expression = right override def prettyName: String = "array_union" } From 98f8d1fdb5b9fc9864b4f5a310c05dc2080204be Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 03:11:01 +0100 Subject: [PATCH 09/34] fix pyspark test failure --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 025c7a9f95732..4e3eb07c26f31 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2027,7 +2027,7 @@ def array_union(col1, col2): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_union(df.c1, df.c2)).collect() - [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f']))] + [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2))) From 30ee7fca9705187cdb1527343f4de8a2e7275736 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 20 Apr 2018 08:54:09 +0100 Subject: [PATCH 10/34] address review comments --- .../expressions/collectionOperations.scala | 204 +++++++++++++----- .../spark/sql/DataFrameFunctionsSuite.scala | 22 +- 2 files changed, 174 insertions(+), 52 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8c581b9b6ca65..9f71a6a039438 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3262,92 +3262,134 @@ case class ArrayDistinct(child: Expression) override def prettyName: String = "array_distinct" } -abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { +object ArraySetUtils { val kindUnion = 1 + val kindIntersect = 2 + val kindExcept = 3 + + def toUnsafeIntArray(hs: OpenHashSet[Int]): UnsafeArrayData = { + val array = new Array[Int](hs.size) + var pos = hs.nextPos(0) + var i 
= 0 + while (pos != OpenHashSet.INVALID_POS) { + array(i) = hs.getValue(pos) + pos = hs.nextPos(pos + 1) + i += 1 + } + UnsafeArrayData.fromPrimitiveArray(array) + } + + def toUnsafeLongArray(hs: OpenHashSet[Long]): UnsafeArrayData = { + val array = new Array[Long](hs.size) + var pos = hs.nextPos(0) + var i = 0 + while (pos != OpenHashSet.INVALID_POS) { + array(i) = hs.getValue(pos) + pos = hs.nextPos(pos + 1) + i += 1 + } + UnsafeArrayData.fromPrimitiveArray(array) + } + + def arrayUnion(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { + new GenericArrayData(array1.toArray[AnyRef](et).union(array2.toArray[AnyRef](et)) + .distinct.asInstanceOf[Array[Any]]) + } + + def arrayIntersect(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { + new GenericArrayData(array1.toArray[AnyRef](et).intersect(array2.toArray[AnyRef](et)) + .distinct.asInstanceOf[Array[Any]]) + } + def arrayExcept(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { + new GenericArrayData(array2.toArray[AnyRef](et).diff(array1.toArray[AnyRef](et)) + .distinct.asInstanceOf[Array[Any]]) + } +} + +abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { def typeId: Int + def array1: Expression + def array2: Expression override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) override def checkInputDataTypes(): TypeCheckResult = { val r = super.checkInputDataTypes() if ((r == TypeCheckResult.TypeCheckSuccess) && - (left.dataType.asInstanceOf[ArrayType].elementType != - right.dataType.asInstanceOf[ArrayType].elementType)) { + (array1.dataType.asInstanceOf[ArrayType].elementType != + array2.dataType.asInstanceOf[ArrayType].elementType)) { TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") } else { r } } - override def dataType: DataType = left.dataType + override def dataType: DataType = array1.dataType private def elementType = dataType.asInstanceOf[ArrayType].elementType - private def cn1 = left.dataType.asInstanceOf[ArrayType].containsNull - private def cn2 = right.dataType.asInstanceOf[ArrayType].containsNull + private def cn = array1.dataType.asInstanceOf[ArrayType].containsNull || + array2.dataType.asInstanceOf[ArrayType].containsNull + + def intEval(ary: ArrayData, hs1: OpenHashSet[Int]): OpenHashSet[Int] + def longEval(ary: ArrayData, hs1: OpenHashSet[Long]): OpenHashSet[Long] + def genericEval(ary: ArrayData, hs1: OpenHashSet[Any], et: DataType): OpenHashSet[Any] + def codeGen(ctx: CodegenContext, hs1: String, hs: String, len: String, getter: String, i: String, + postFix: String, newOpenHashSet: String): String override def nullSafeEval(input1: Any, input2: Any): Any = { val ary1 = input1.asInstanceOf[ArrayData] val ary2 = input2.asInstanceOf[ArrayData] - if (!cn1 && !cn2) { + if (!cn) { elementType match { case IntegerType => // avoid boxing of primitive int array elements - val hs = new OpenHashSet[Int] + val hs1 = new OpenHashSet[Int] var i = 0 while (i < ary1.numElements()) { - hs.add(ary1.getInt(i)) + hs1.add(ary1.getInt(i)) i += 1 } - i = 0 - while (i < ary2.numElements()) { - hs.add(ary2.getInt(i)) - i += 1 - } - UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) + ArraySetUtils.toUnsafeIntArray(intEval(ary2, hs1)) case LongType => // avoid boxing of primitive long array elements - val hs = new OpenHashSet[Long] + val hs1 = new OpenHashSet[Long] var i = 0 while (i < ary1.numElements()) { - hs.add(ary1.getLong(i)) + hs1.add(ary1.getLong(i)) i += 1 } - i = 0 - while (i < ary2.numElements()) { - 
hs.add(ary2.getLong(i)) - i += 1 - } - UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) + ArraySetUtils.toUnsafeLongArray(longEval(ary2, hs1)) case _ => - val hs = new OpenHashSet[Any] + var hs: OpenHashSet[Any] = null + val hs1 = new OpenHashSet[Any] var i = 0 while (i < ary1.numElements()) { - hs.add(ary1.get(i, elementType)) + hs1.add(ary1.get(i, elementType)) i += 1 } - i = 0 - while (i < ary2.numElements()) { - hs.add(ary2.get(i, elementType)) - i += 1 - } - new GenericArrayData(hs.iterator.toArray) + new GenericArrayData(genericEval(ary2, hs1, elementType).iterator.toArray) } } else { - ArraySetUtils.arrayUnion(ary1, ary2, elementType) + if (typeId == ArraySetUtils.kindUnion) { + ArraySetUtils.arrayUnion(ary1, ary2, elementType) + } else if (typeId == ArraySetUtils.kindIntersect) { + ArraySetUtils.arrayIntersect(ary1, ary2, elementType) + } else { + ArraySetUtils.arrayExcept(ary1, ary2, elementType) + } } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val hs = ctx.freshName("hs") val i = ctx.freshName("i") - val ArraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetUtils" + val arraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetUtils" val genericArrayData = classOf[GenericArrayData].getName val unsafeArrayData = classOf[UnsafeArrayData].getName val openHashSet = classOf[OpenHashSet[_]].getName val et = s"org.apache.spark.sql.types.DataTypes.$elementType" - val (postFix, classTag, getter, arrayBuilder, castType) = if (!cn1 && !cn2) { + val (postFix, classTag, getter, arrayBuilder, javaTypeName) = if (!cn) { val ptName = CodeGenerator.primitiveTypeName(elementType) elementType match { case ByteType | ShortType | IntegerType => @@ -3364,32 +3406,46 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { ("", "", "", "", "") } + val hs = ctx.freshName("hs") + val hs1 = ctx.freshName("hs1") + val invalidPos = ctx.freshName("invalidPos") + val pos = ctx.freshName("pos") + val ary = ctx.freshName("ary") nullSafeCodeGen(ctx, ev, (ary1, ary2) => { if (classTag != "") { + val secondLoop = codeGen(ctx, hs1, hs, s"$ary2.numElements()", s"$ary2.$getter", i, + postFix, s"new $openHashSet$postFix($classTag)") s""" - |$openHashSet $hs = new $openHashSet$postFix($classTag); + |$openHashSet $hs1 = new $openHashSet$postFix($classTag); |for (int $i = 0; $i < $ary1.numElements(); $i++) { - | $hs.add$postFix($ary1.$getter); + | $hs1.add$postFix($ary1.$getter); |} - |for (int $i = 0; $i < $ary2.numElements(); $i++) { - | $hs.add$postFix($ary2.$getter); + |$secondLoop + |$javaTypeName[] $ary = new $javaTypeName[$hs.size()]; + |int $invalidPos = $openHashSet.INVALID_POS(); + |int $pos = $hs.nextPos(0); + |int $i = 0; + |while ($pos != $invalidPos) { + | $ary[$i] = ($javaTypeName) $hs.getValue$postFix($pos); + | $pos = $hs.nextPos($pos + 1); + | $i++; |} - |${ev.value} = $arrayBuilder(($castType[]) $hs.iterator().toArray($classTag)); + |${ev.value} = $arrayBuilder($ary); """.stripMargin } else { - s"${ev.value} = $ArraySetUtils$$.MODULE$$.arrayUnion($ary1, $ary2, $et);" + val setOp = if (typeId == ArraySetUtils.kindUnion) { + "Union" + } else if (typeId == ArraySetUtils.kindIntersect) { + "Intersect" + } else { + "Except" + } + s"${ev.value} = $arraySetUtils$$.MODULE$$.array$setOp($ary1, $ary2, $et);" } }) } } -object ArraySetUtils { - def arrayUnion(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { - new GenericArrayData(array1.toArray[AnyRef](et).union(array2.toArray[AnyRef](et)) - 
.distinct.asInstanceOf[Array[Any]]) - } -} - /** * Returns an array of the elements in the union of x and y, without duplicates */ @@ -3405,7 +3461,57 @@ object ArraySetUtils { """, since = "2.4.0") case class ArrayUnion(left: Expression, right: Expression) extends ArraySetUtils { - override def typeId: Int = kindUnion + override def typeId: Int = ArraySetUtils.kindUnion + override def array1: Expression = left + override def array2: Expression = right + + override def intEval(ary: ArrayData, hs1: OpenHashSet[Int]): OpenHashSet[Int] = { + var i = 0 + while (i < ary.numElements()) { + hs1.add(ary.getInt(i)) + i += 1 + } + hs1 + } + + override def longEval(ary: ArrayData, hs1: OpenHashSet[Long]): OpenHashSet[Long] = { + var i = 0 + while (i < ary.numElements()) { + hs1.add(ary.getLong(i)) + i += 1 + } + hs1 + } + + override def genericEval( + ary: ArrayData, + hs1: OpenHashSet[Any], + et: DataType): OpenHashSet[Any] = { + var i = 0 + while (i < ary.numElements()) { + hs1.add(ary.get(i, et)) + i += 1 + } + hs1 + } + + override def codeGen( + ctx: CodegenContext, + hs1: String, + hs: String, + len: String, + getter: String, + i: String, + postFix: String, + newOpenHashSet: String): String = { + val openHashSet = classOf[OpenHashSet[_]].getName + s""" + |for (int $i = 0; $i < $len; $i++) { + | $hs1.add$postFix($getter); + |} + |$openHashSet $hs = $hs1; + """.stripMargin + } override def prettyName: String = "array_union" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 1a1b91024be88..55f747af1a3f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1152,12 +1152,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df6.select(array_union($"a", $"b")), ans6) checkAnswer(df6.selectExpr("array_union(a, b)"), ans6) - val df0 = Seq((Array(1), Array("a"))).toDF("a", "b") + val df7 = Seq((Array(1), Array("a"))).toDF("a", "b") intercept[AnalysisException] { - df0.select(array_union($"a", $"b")) + df7.select(array_union($"a", $"b")) } intercept[AnalysisException] { - df0.selectExpr("array_contains(a, b)") + df7.selectExpr("array_contains(a, b)") + } + + val df8 = Seq((null, Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df8.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df8.selectExpr("array_contains(a, b)") + } + + val df9 = Seq((Array("a"), null)).toDF("a", "b") + intercept[AnalysisException] { + df9.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df9.selectExpr("array_contains(a, b)") } } From cd347e9291ca026cc864745e9991d70a6d12b2f2 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 20 Apr 2018 10:17:37 +0100 Subject: [PATCH 11/34] add new tests based on review comment --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 332dff183ac99..1ef6ae9c662f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ 
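The refactor above is a template method: `ArraySetUtils` owns hashing of the first array, and each set operation supplies `intEval`/`longEval`/`genericEval` for the second. In miniature — hypothetical names, with a boxed `LinkedHashSet` standing in for `OpenHashSet`:

```scala
import scala.collection.mutable

abstract class SetOpSketch {
  protected def combine(hs: mutable.LinkedHashSet[Int], right: Array[Int]): mutable.LinkedHashSet[Int]
  final def eval(left: Array[Int], right: Array[Int]): Array[Int] = {
    val hs = mutable.LinkedHashSet(left: _*)   // base class always hashes the first array
    combine(hs, right).toArray                 // subclasses decide how the second combines
  }
}

object UnionSketch extends SetOpSketch {
  protected def combine(hs: mutable.LinkedHashSet[Int], right: Array[Int]) = {
    right.foreach(hs += _)   // union: just add; the set drops duplicates
    hs
  }
}

UnionSketch.eval(Array(1, 2, 3), Array(4, 2))   // Array(1, 2, 3, 4)
```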
-1189,6 +1189,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a24 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, false)) val a30 = Literal.create(Seq(null, null), ArrayType(NullType)) + val a31 = Literal.create(null, ArrayType(StringType)) checkEvaluation(ArrayUnion(a00, a01), UnsafeArrayData.fromPrimitiveArray(Array(4, 1, 3, 2))) checkEvaluation(ArrayUnion(a01, a02), Seq(4, 2, 1, 3)) @@ -1206,5 +1207,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayUnion(a23, a24), Seq("b", "c", "d", "a", "f")) checkEvaluation(ArrayUnion(a30, a30), Seq(null)) + checkEvaluation(ArrayUnion(a20, a31), null) + checkEvaluation(ArrayUnion(a31, a20), null) } } From d2eaee3ac5b770dd2f8bda2c5786e6ddc18327ee Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 20 Apr 2018 11:21:01 +0100 Subject: [PATCH 12/34] fix mistakes in rebase --- .../expressions/collectionOperations.scala | 38 ++++++++++++------- .../spark/sql/DataFrameFunctionsSuite.scala | 1 + 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 9f71a6a039438..e526965114f4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2021,7 +2021,7 @@ case class Concat(children: Seq[Expression]) extends Expression { ByteArray.concat(inputs: _*) case StringType => val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs: _*) + UTF8String.concat(inputs : _*) case ArrayType(elementType, _) => val inputs = children.toStream.map(_.eval(input)) if (inputs.contains(null)) { @@ -2035,7 +2035,7 @@ case class Concat(children: Seq[Expression]) extends Expression { } val finalData = new Array[AnyRef](numberOfElements.toInt) var position = 0 - for (ad <- arrayData) { + for(ad <- arrayData) { val arr = ad.toObjectArray(elementType) Array.copy(arr, 0, finalData, position, arr.length) position += arr.length @@ -2082,24 +2082,23 @@ case class Concat(children: Seq[Expression]) extends Expression { """) } - private def genCodeForNumberOfElements(ctx: CodegenContext): (String, String) = { + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { val numElements = ctx.freshName("numElements") - val code = - s""" - |long $numElements = 0L; - |for (int z = 0; z < ${children.length}; z++) { - | $numElements += args[z].numElements(); - |} - |if ($numElements > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + - | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); - |} + val code = s""" + |long $numElements = 0L; + |for (int z = 0; z < ${children.length}; z++) { + | $numElements += args[z].numElements(); + |} + |if ($numElements > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + + | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} """.stripMargin (code, numElements) } - private def nullArgumentProtection(): String = { + private def nullArgumentProtection() : String = { if (nullable) { s""" |for (int z = 0; z < ${children.length}; z++) { @@ -2117,6 +2116,17 @@ case class 
Concat(children: Seq[Expression]) extends Expression { val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + val unsafeArraySizeInBytes = s""" + |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElemName, + | ${elementType.defaultSize}); + |if ($arraySizeName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName + + | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" + + | " for UnsafeArrayData."); + |} + """.stripMargin + val baseOffset = Platform.BYTE_ARRAY_OFFSET val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 55f747af1a3f2..82ec1219d116a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1161,6 +1161,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } val df8 = Seq((null, Array("a"))).toDF("a", "b") + df8.select(array_union($"a", $"b")) intercept[AnalysisException] { df8.select(array_union($"a", $"b")) } From 2ddeb0619612645dbd79b09b06b5f053d4975253 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 20 Apr 2018 11:29:34 +0100 Subject: [PATCH 13/34] fix unexpected changes --- python/pyspark/sql/functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4e3eb07c26f31..a25d75f5776b4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1942,7 +1942,6 @@ def concat(*cols): return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) -@ignore_unicode_prefix @since(2.4) def array_position(col, value): """ From 71b31f00644756c016af2518dedec608ebb7cb01 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 20 Apr 2018 14:14:59 +0100 Subject: [PATCH 14/34] merge changes in #21103 fix test failure --- .../expressions/collectionOperations.scala | 93 ++++++++----------- .../spark/sql/DataFrameFunctionsSuite.scala | 1 - 2 files changed, 41 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e526965114f4d..e56ff7c6c606c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3274,8 +3274,6 @@ case class ArrayDistinct(child: Expression) object ArraySetUtils { val kindUnion = 1 - val kindIntersect = 2 - val kindExcept = 3 def toUnsafeIntArray(hs: OpenHashSet[Int]): UnsafeArrayData = { val array = new Array[Int](hs.size) @@ -3312,39 +3310,37 @@ object ArraySetUtils { } def arrayExcept(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { - new GenericArrayData(array2.toArray[AnyRef](et).diff(array1.toArray[AnyRef](et)) + new GenericArrayData(array1.toArray[AnyRef](et).diff(array2.toArray[AnyRef](et)) .distinct.asInstanceOf[Array[Any]]) } } abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { def typeId: Int - def array1: Expression - def array2: Expression override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, 
ArrayType) override def checkInputDataTypes(): TypeCheckResult = { val r = super.checkInputDataTypes() if ((r == TypeCheckResult.TypeCheckSuccess) && - (array1.dataType.asInstanceOf[ArrayType].elementType != - array2.dataType.asInstanceOf[ArrayType].elementType)) { + (left.dataType.asInstanceOf[ArrayType].elementType != + right.dataType.asInstanceOf[ArrayType].elementType)) { TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") } else { r } } - override def dataType: DataType = array1.dataType + override def dataType: DataType = left.dataType private def elementType = dataType.asInstanceOf[ArrayType].elementType - private def cn = array1.dataType.asInstanceOf[ArrayType].containsNull || - array2.dataType.asInstanceOf[ArrayType].containsNull + private def cn = left.dataType.asInstanceOf[ArrayType].containsNull || + right.dataType.asInstanceOf[ArrayType].containsNull - def intEval(ary: ArrayData, hs1: OpenHashSet[Int]): OpenHashSet[Int] - def longEval(ary: ArrayData, hs1: OpenHashSet[Long]): OpenHashSet[Long] - def genericEval(ary: ArrayData, hs1: OpenHashSet[Any], et: DataType): OpenHashSet[Any] - def codeGen(ctx: CodegenContext, hs1: String, hs: String, len: String, getter: String, i: String, + def intEval(ary: ArrayData, hs2: OpenHashSet[Int]): OpenHashSet[Int] + def longEval(ary: ArrayData, hs2: OpenHashSet[Long]): OpenHashSet[Long] + def genericEval(ary: ArrayData, hs2: OpenHashSet[Any], et: DataType): OpenHashSet[Any] + def codeGen(ctx: CodegenContext, hs2: String, hs: String, len: String, getter: String, i: String, postFix: String, newOpenHashSet: String): String override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -3355,39 +3351,36 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { elementType match { case IntegerType => // avoid boxing of primitive int array elements - val hs1 = new OpenHashSet[Int] + val hs2 = new OpenHashSet[Int] var i = 0 - while (i < ary1.numElements()) { - hs1.add(ary1.getInt(i)) + while (i < ary2.numElements()) { + hs2.add(ary2.getInt(i)) i += 1 } - ArraySetUtils.toUnsafeIntArray(intEval(ary2, hs1)) + ArraySetUtils.toUnsafeIntArray(intEval(ary1, hs2)) case LongType => // avoid boxing of primitive long array elements - val hs1 = new OpenHashSet[Long] + val hs2 = new OpenHashSet[Long] var i = 0 - while (i < ary1.numElements()) { - hs1.add(ary1.getLong(i)) + while (i < ary2.numElements()) { + hs2.add(ary2.getLong(i)) i += 1 } - ArraySetUtils.toUnsafeLongArray(longEval(ary2, hs1)) + ArraySetUtils.toUnsafeLongArray(longEval(ary1, hs2)) case _ => - var hs: OpenHashSet[Any] = null - val hs1 = new OpenHashSet[Any] + val hs2 = new OpenHashSet[Any] var i = 0 - while (i < ary1.numElements()) { - hs1.add(ary1.get(i, elementType)) + while (i < ary2.numElements()) { + hs2.add(ary2.get(i, elementType)) i += 1 } - new GenericArrayData(genericEval(ary2, hs1, elementType).iterator.toArray) + new GenericArrayData(genericEval(ary1, hs2, elementType).iterator.toArray) } } else { if (typeId == ArraySetUtils.kindUnion) { ArraySetUtils.arrayUnion(ary1, ary2, elementType) - } else if (typeId == ArraySetUtils.kindIntersect) { - ArraySetUtils.arrayIntersect(ary1, ary2, elementType) } else { - ArraySetUtils.arrayExcept(ary1, ary2, elementType) + null } } } @@ -3417,18 +3410,18 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { } val hs = ctx.freshName("hs") - val hs1 = ctx.freshName("hs1") + val hs2 = ctx.freshName("hs2") val invalidPos = ctx.freshName("invalidPos") val pos = 
ctx.freshName("pos") val ary = ctx.freshName("ary") nullSafeCodeGen(ctx, ev, (ary1, ary2) => { if (classTag != "") { - val secondLoop = codeGen(ctx, hs1, hs, s"$ary2.numElements()", s"$ary2.$getter", i, + val secondLoop = codeGen(ctx, hs2, hs, s"$ary1.numElements()", s"$ary1.$getter", i, postFix, s"new $openHashSet$postFix($classTag)") s""" - |$openHashSet $hs1 = new $openHashSet$postFix($classTag); - |for (int $i = 0; $i < $ary1.numElements(); $i++) { - | $hs1.add$postFix($ary1.$getter); + |$openHashSet $hs2 = new $openHashSet$postFix($classTag); + |for (int $i = 0; $i < $ary2.numElements(); $i++) { + | $hs2.add$postFix($ary2.$getter); |} |$secondLoop |$javaTypeName[] $ary = new $javaTypeName[$hs.size()]; @@ -3445,10 +3438,8 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { } else { val setOp = if (typeId == ArraySetUtils.kindUnion) { "Union" - } else if (typeId == ArraySetUtils.kindIntersect) { - "Intersect" } else { - "Except" + "" } s"${ev.value} = $arraySetUtils$$.MODULE$$.array$setOp($ary1, $ary2, $et);" } @@ -3472,42 +3463,40 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { since = "2.4.0") case class ArrayUnion(left: Expression, right: Expression) extends ArraySetUtils { override def typeId: Int = ArraySetUtils.kindUnion - override def array1: Expression = left - override def array2: Expression = right - override def intEval(ary: ArrayData, hs1: OpenHashSet[Int]): OpenHashSet[Int] = { + override def intEval(ary: ArrayData, hs2: OpenHashSet[Int]): OpenHashSet[Int] = { var i = 0 while (i < ary.numElements()) { - hs1.add(ary.getInt(i)) + hs2.add(ary.getInt(i)) i += 1 } - hs1 + hs2 } - override def longEval(ary: ArrayData, hs1: OpenHashSet[Long]): OpenHashSet[Long] = { + override def longEval(ary: ArrayData, hs2: OpenHashSet[Long]): OpenHashSet[Long] = { var i = 0 while (i < ary.numElements()) { - hs1.add(ary.getLong(i)) + hs2.add(ary.getLong(i)) i += 1 } - hs1 + hs2 } override def genericEval( ary: ArrayData, - hs1: OpenHashSet[Any], + hs2: OpenHashSet[Any], et: DataType): OpenHashSet[Any] = { var i = 0 while (i < ary.numElements()) { - hs1.add(ary.get(i, et)) + hs2.add(ary.get(i, et)) i += 1 } - hs1 + hs2 } override def codeGen( ctx: CodegenContext, - hs1: String, + hs2: String, hs: String, len: String, getter: String, @@ -3517,9 +3506,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetUtils val openHashSet = classOf[OpenHashSet[_]].getName s""" |for (int $i = 0; $i < $len; $i++) { - | $hs1.add$postFix($getter); + | $hs2.add$postFix($getter); |} - |$openHashSet $hs = $hs1; + |$openHashSet $hs = $hs2; """.stripMargin } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 82ec1219d116a..55f747af1a3f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1161,7 +1161,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } val df8 = Seq((null, Array("a"))).toDF("a", "b") - df8.select(array_union($"a", $"b")) intercept[AnalysisException] { df8.select(array_union($"a", $"b")) } From 7e71340d52ffa1204bcd22841162dc8a8584e7a7 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 4 May 2018 02:24:39 +0100 Subject: [PATCH 15/34] use GenericArrayData if UnsafeArrayData cannot be used use ctx.addReferenceObj for DataType --- 
.../expressions/collectionOperations.scala | 62 ++++++++++++++----- 1 file changed, 45 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e56ff7c6c606c..4847067c4a181 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3275,7 +3275,7 @@ case class ArrayDistinct(child: Expression) object ArraySetUtils { val kindUnion = 1 - def toUnsafeIntArray(hs: OpenHashSet[Int]): UnsafeArrayData = { + def toArrayDataInt(hs: OpenHashSet[Int]): ArrayData = { val array = new Array[Int](hs.size) var pos = hs.nextPos(0) var i = 0 @@ -3284,10 +3284,18 @@ object ArraySetUtils { pos = hs.nextPos(pos + 1) i += 1 } - UnsafeArrayData.fromPrimitiveArray(array) + + val numBytes = 4L * array.length + val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(array.length) + + org.apache.spark.unsafe.array.ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + if (unsafeArraySizeInBytes <= Integer.MAX_VALUE) { + UnsafeArrayData.fromPrimitiveArray(array) + } else { + new GenericArrayData(array) + } } - def toUnsafeLongArray(hs: OpenHashSet[Long]): UnsafeArrayData = { + def toArrayDataLong(hs: OpenHashSet[Long]): ArrayData = { val array = new Array[Long](hs.size) var pos = hs.nextPos(0) var i = 0 @@ -3296,7 +3304,15 @@ object ArraySetUtils { pos = hs.nextPos(pos + 1) i += 1 } - UnsafeArrayData.fromPrimitiveArray(array) + + val numBytes = 8L * array.length + val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(array.length) + + org.apache.spark.unsafe.array.ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + if (unsafeArraySizeInBytes <= Integer.MAX_VALUE) { + UnsafeArrayData.fromPrimitiveArray(array) + } else { + new GenericArrayData(array) + } } def arrayUnion(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { @@ -3357,7 +3373,7 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { hs2.add(ary2.getInt(i)) i += 1 } - ArraySetUtils.toUnsafeIntArray(intEval(ary1, hs2)) + ArraySetUtils.toArrayDataInt(intEval(ary1, hs2)) case LongType => // avoid boxing of primitive long array elements val hs2 = new OpenHashSet[Long] @@ -3366,7 +3382,7 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { hs2.add(ary2.getLong(i)) i += 1 } - ArraySetUtils.toUnsafeLongArray(longEval(ary1, hs2)) + ArraySetUtils.toArrayDataLong(longEval(ary1, hs2)) case _ => val hs2 = new OpenHashSet[Any] var i = 0 @@ -3387,23 +3403,34 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val i = ctx.freshName("i") + val ary = ctx.freshName("ary") val arraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetUtils" val genericArrayData = classOf[GenericArrayData].getName val unsafeArrayData = classOf[UnsafeArrayData].getName val openHashSet = classOf[OpenHashSet[_]].getName - val et = s"org.apache.spark.sql.types.DataTypes.$elementType" - val (postFix, classTag, getter, arrayBuilder, javaTypeName) = if (!cn) { + val (postFix, classTag, getter, javaTypeName, arrayBuilder) = if (!cn) { val ptName = CodeGenerator.primitiveTypeName(elementType) elementType match { - case ByteType | 
ShortType | IntegerType => - (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", - s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType)) - case LongType => - (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", - s"$unsafeArrayData.fromPrimitiveArray", "long") + case ByteType | ShortType | IntegerType | LongType => + (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", + s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", + CodeGenerator.javaType(elementType), + s""" + |long numBytes = (long) ${elementType.defaultSize} * $ary.length; + |long unsafeArraySizeInBytes = + | $unsafeArrayData.calculateHeaderPortionInBytes($ary.length) + + | org.apache.spark.unsafe.array.ByteArrayMethods + | .roundNumberOfBytesToNearestWord(numBytes); + |if (unsafeArraySizeInBytes <= Integer.MAX_VALUE) { + | ${ev.value} = $unsafeArrayData.fromPrimitiveArray($ary); + |} else { + | ${ev.value} = new $genericArrayData($ary); + |} + """.stripMargin) case _ => + val et = ctx.addReferenceObj("elementType", elementType) ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $et)", - s"new $genericArrayData", "Object") + "Object", s"${ev.value} = new $genericArrayData($ary);") } } else { ("", "", "", "", "") @@ -3413,7 +3440,6 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { val hs2 = ctx.freshName("hs2") val invalidPos = ctx.freshName("invalidPos") val pos = ctx.freshName("pos") - val ary = ctx.freshName("ary") nullSafeCodeGen(ctx, ev, (ary1, ary2) => { if (classTag != "") { val secondLoop = codeGen(ctx, hs2, hs, s"$ary1.numElements()", s"$ary1.$getter", i, @@ -3433,7 +3459,8 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { | $pos = $hs.nextPos($pos + 1); | $i++; |} - |${ev.value} = $arrayBuilder($ary); + | + |$arrayBuilder """.stripMargin } else { val setOp = if (typeId == ArraySetUtils.kindUnion) { @@ -3441,6 +3468,7 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { } else { "" } + val et = ctx.addReferenceObj("elementTypeUtil", elementType) s"${ev.value} = $arraySetUtils$$.MODULE$$.array$setOp($ary1, $ary2, $et);" } }) From 04c97c3b5b62b29371569ba3faf43bba075ac343 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 4 May 2018 13:01:42 +0100 Subject: [PATCH 16/34] use BinaryArrayExpressionWithImplicitCast rename ArraySetUtils to ArraySetLike update an condition to use GenericArrayData --- .../expressions/collectionOperations.scala | 40 +++++++------------ 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4847067c4a181..5044d10820c0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3272,7 +3272,7 @@ case class ArrayDistinct(child: Expression) override def prettyName: String = "array_distinct" } -object ArraySetUtils { +object ArraySetLike { val kindUnion = 1 def toArrayDataInt(hs: OpenHashSet[Int]): ArrayData = { @@ -3288,7 +3288,8 @@ object ArraySetUtils { val numBytes = 4L * array.length val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(array.length) + 
org.apache.spark.unsafe.array.ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) - if (unsafeArraySizeInBytes <= Integer.MAX_VALUE) { + // Since UnsafeArrayData.fromPrimitiveArray() uses long[], max elements * 8 bytes can be used + if (unsafeArraySizeInBytes <= Integer.MAX_VALUE * 8) { UnsafeArrayData.fromPrimitiveArray(array) } else { new GenericArrayData(array) @@ -3308,7 +3309,8 @@ object ArraySetUtils { val numBytes = 8L * array.length val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(array.length) + org.apache.spark.unsafe.array.ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) - if (unsafeArraySizeInBytes <= Integer.MAX_VALUE) { + // Since UnsafeArrayData.fromPrimitiveArray() uses long[], max elements * 8 bytes can be used + if (unsafeArraySizeInBytes <= Integer.MAX_VALUE * 8) { UnsafeArrayData.fromPrimitiveArray(array) } else { new GenericArrayData(array) @@ -3331,25 +3333,11 @@ object ArraySetUtils { } } -abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { def typeId: Int - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) - - override def checkInputDataTypes(): TypeCheckResult = { - val r = super.checkInputDataTypes() - if ((r == TypeCheckResult.TypeCheckSuccess) && - (left.dataType.asInstanceOf[ArrayType].elementType != - right.dataType.asInstanceOf[ArrayType].elementType)) { - TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") - } else { - r - } - } - override def dataType: DataType = left.dataType - private def elementType = dataType.asInstanceOf[ArrayType].elementType private def cn = left.dataType.asInstanceOf[ArrayType].containsNull || right.dataType.asInstanceOf[ArrayType].containsNull @@ -3373,7 +3361,7 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { hs2.add(ary2.getInt(i)) i += 1 } - ArraySetUtils.toArrayDataInt(intEval(ary1, hs2)) + ArraySetLike.toArrayDataInt(intEval(ary1, hs2)) case LongType => // avoid boxing of primitive long array elements val hs2 = new OpenHashSet[Long] @@ -3382,7 +3370,7 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { hs2.add(ary2.getLong(i)) i += 1 } - ArraySetUtils.toArrayDataLong(longEval(ary1, hs2)) + ArraySetLike.toArrayDataLong(longEval(ary1, hs2)) case _ => val hs2 = new OpenHashSet[Any] var i = 0 @@ -3393,8 +3381,8 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { new GenericArrayData(genericEval(ary1, hs2, elementType).iterator.toArray) } } else { - if (typeId == ArraySetUtils.kindUnion) { - ArraySetUtils.arrayUnion(ary1, ary2, elementType) + if (typeId == ArraySetLike.kindUnion) { + ArraySetLike.arrayUnion(ary1, ary2, elementType) } else { null } @@ -3404,7 +3392,7 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val i = ctx.freshName("i") val ary = ctx.freshName("ary") - val arraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetUtils" + val arraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetLike" val genericArrayData = classOf[GenericArrayData].getName val unsafeArrayData = classOf[UnsafeArrayData].getName val openHashSet = classOf[OpenHashSet[_]].getName @@ -3463,7 +3451,7 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { |$arrayBuilder """.stripMargin } else { - val setOp = if 
(typeId == ArraySetUtils.kindUnion) { + val setOp = if (typeId == ArraySetLike.kindUnion) { "Union" } else { "" @@ -3489,8 +3477,8 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { array(1, 2, 3, 5) """, since = "2.4.0") -case class ArrayUnion(left: Expression, right: Expression) extends ArraySetUtils { - override def typeId: Int = ArraySetUtils.kindUnion +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { + override def typeId: Int = ArraySetLike.kindUnion override def intEval(ary: ArrayData, hs2: OpenHashSet[Int]): OpenHashSet[Int] = { var i = 0 From 401ca7a9678741e03597b065774165513cbe882b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 4 May 2018 19:49:48 +0100 Subject: [PATCH 17/34] update test cases fix test failure --- .../spark/sql/DataFrameFunctionsSuite.scala | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 55f747af1a3f2..f83829d690d9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1147,33 +1147,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df5.select(array_union($"a", $"b")), ans5) checkAnswer(df5.selectExpr("array_union(a, b)"), ans5) - val df6 = Seq((null, null)).toDF("a", "b") - val ans6 = Row(null) + val df6 = Seq((Array(1), Array("a"))).toDF("a", "b") + val ans6 = Row(Seq("1", "a")) checkAnswer(df6.select(array_union($"a", $"b")), ans6) checkAnswer(df6.selectExpr("array_union(a, b)"), ans6) - val df7 = Seq((Array(1), Array("a"))).toDF("a", "b") - intercept[AnalysisException] { - df7.select(array_union($"a", $"b")) - } - intercept[AnalysisException] { - df7.selectExpr("array_contains(a, b)") - } + val df7 = Seq((null, Array("a"))).toDF("a", "b") + val ans7 = Row(null) + checkAnswer(df7.select(array_union($"a", $"b")), ans7) + checkAnswer(df7.selectExpr("array_union(a, b)"), ans7) - val df8 = Seq((null, Array("a"))).toDF("a", "b") + val df8 = Seq((null, null)).toDF("a", "b") intercept[AnalysisException] { df8.select(array_union($"a", $"b")) } intercept[AnalysisException] { - df8.selectExpr("array_contains(a, b)") + df8.selectExpr("array_union(a, b)") } - val df9 = Seq((Array("a"), null)).toDF("a", "b") + val df9 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b") intercept[AnalysisException] { df9.select(array_union($"a", $"b")) } intercept[AnalysisException] { - df9.selectExpr("array_contains(a, b)") + df9.selectExpr("array_union(a, b)") } } From 15b953bea6f95b469abf309ebd76665f4a7ced74 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 17 May 2018 19:47:04 +0100 Subject: [PATCH 18/34] rebase with master --- .../expressions/collectionOperations.scala | 50 +++++++------------ .../spark/sql/DataFrameFunctionsSuite.scala | 28 +++++------ 2 files changed, 30 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5044d10820c0f..00a8dee0ff7a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala 
@@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -2116,17 +2115,6 @@ case class Concat(children: Seq[Expression]) extends Expression { val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - val unsafeArraySizeInBytes = s""" - |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElemName, - | ${elementType.defaultSize}); - |if ($arraySizeName > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName + - | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" + - | " for UnsafeArrayData."); - |} - """.stripMargin - val baseOffset = Platform.BYTE_ARRAY_OFFSET val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) s""" @@ -2869,6 +2857,7 @@ case class ArrayRepeat(left: Expression, right: Expression) |$arrayDataName = new $genericArrayClass($arrayName); """.stripMargin } + } /** @@ -3391,37 +3380,32 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val i = ctx.freshName("i") + val value = ctx.freshName("value") + val size = ctx.freshName("size") val ary = ctx.freshName("ary") val arraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetLike" val genericArrayData = classOf[GenericArrayData].getName - val unsafeArrayData = classOf[UnsafeArrayData].getName val openHashSet = classOf[OpenHashSet[_]].getName - val (postFix, classTag, getter, javaTypeName, arrayBuilder) = if (!cn) { + val (postFix, classTag, getter, setter, javaTypeName, arrayBuilder) = if (!cn) { val ptName = CodeGenerator.primitiveTypeName(elementType) elementType match { case ByteType | ShortType | IntegerType | LongType => + val uary = ctx.freshName("uary") (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", - s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", - CodeGenerator.javaType(elementType), + s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", + s"get$ptName($i)", s"set$ptName($i, $value)", CodeGenerator.javaType(elementType), s""" - |long numBytes = (long) ${elementType.defaultSize} * $ary.length; - |long unsafeArraySizeInBytes = - | $unsafeArrayData.calculateHeaderPortionInBytes($ary.length) + - | org.apache.spark.unsafe.array.ByteArrayMethods - | .roundNumberOfBytesToNearestWord(numBytes); - |if (unsafeArraySizeInBytes <= Integer.MAX_VALUE) { - | ${ev.value} = $unsafeArrayData.fromPrimitiveArray($ary); - |} else { - | ${ev.value} = new $genericArrayData($ary); - |} + |${ctx.createUnsafeArray(uary, size, elementType, s" $prettyName failed.")} + |${ev.value} = $uary; """.stripMargin) case _ => val et = ctx.addReferenceObj("elementType", elementType) - ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $et)", - "Object", s"${ev.value} = new $genericArrayData($ary);") + ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", + s"get($i, $et)", s"update($i, $value)", "Object", + s"${ev.value} = new $genericArrayData(new Object[$size]);") } } else { - ("", "", "", "", "") + ("", "", "", "", "", "") } val hs = ctx.freshName("hs") @@ -3438,17 +3422,17 @@ abstract class 
ArraySetLike extends BinaryArrayExpressionWithImplicitCast { | $hs2.add$postFix($ary2.$getter); |} |$secondLoop - |$javaTypeName[] $ary = new $javaTypeName[$hs.size()]; + |int $size = $hs.size(); + |$arrayBuilder |int $invalidPos = $openHashSet.INVALID_POS(); |int $pos = $hs.nextPos(0); |int $i = 0; |while ($pos != $invalidPos) { - | $ary[$i] = ($javaTypeName) $hs.getValue$postFix($pos); + | $javaTypeName $value = ($javaTypeName) $hs.getValue$postFix($pos); + | ${ev.value}.$setter; | $pos = $hs.nextPos($pos + 1); | $i++; |} - | - |$arrayBuilder """.stripMargin } else { val setOp = if (typeId == ArraySetLike.kindUnion) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index f83829d690d9f..10180d75e9505 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1147,30 +1147,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df5.select(array_union($"a", $"b")), ans5) checkAnswer(df5.selectExpr("array_union(a, b)"), ans5) - val df6 = Seq((Array(1), Array("a"))).toDF("a", "b") - val ans6 = Row(Seq("1", "a")) - checkAnswer(df6.select(array_union($"a", $"b")), ans6) - checkAnswer(df6.selectExpr("array_union(a, b)"), ans6) - - val df7 = Seq((null, Array("a"))).toDF("a", "b") - val ans7 = Row(null) - checkAnswer(df7.select(array_union($"a", $"b")), ans7) - checkAnswer(df7.selectExpr("array_union(a, b)"), ans7) + val df6 = Seq((null, Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df6.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df6.selectExpr("array_union(a, b)") + } - val df8 = Seq((null, null)).toDF("a", "b") + val df7 = Seq((null, null)).toDF("a", "b") intercept[AnalysisException] { - df8.select(array_union($"a", $"b")) + df7.select(array_union($"a", $"b")) } intercept[AnalysisException] { - df8.selectExpr("array_union(a, b)") + df7.selectExpr("array_union(a, b)") } - val df9 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b") + val df8 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b") intercept[AnalysisException] { - df9.select(array_union($"a", $"b")) + df8.select(array_union($"a", $"b")) } intercept[AnalysisException] { - df9.selectExpr("array_union(a, b)") + df8.selectExpr("array_union(a, b)") } } From f05092242bfbed7d7fa335acc18fb7ee1a7ab6ae Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 18 May 2018 20:37:13 +0100 Subject: [PATCH 19/34] support complex types --- .../expressions/collectionOperations.scala | 90 +++++++++++++++---- .../CollectionExpressionsSuite.scala | 28 +++++- 2 files changed, 102 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 00a8dee0ff7a9..0728099a50159 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3264,6 +3264,8 @@ case class ArrayDistinct(child: Expression) object ArraySetLike { val kindUnion = 1 + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + def toArrayDataInt(hs: OpenHashSet[Int]): ArrayData = { val array = new Array[Int](hs.size) var pos = 
hs.nextPos(0) @@ -3306,19 +3308,55 @@ object ArraySetLike { } } - def arrayUnion(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { - new GenericArrayData(array1.toArray[AnyRef](et).union(array2.toArray[AnyRef](et)) - .distinct.asInstanceOf[Array[Any]]) - } - - def arrayIntersect(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { - new GenericArrayData(array1.toArray[AnyRef](et).intersect(array2.toArray[AnyRef](et)) - .distinct.asInstanceOf[Array[Any]]) - } - - def arrayExcept(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { - new GenericArrayData(array1.toArray[AnyRef](et).diff(array2.toArray[AnyRef](et)) - .distinct.asInstanceOf[Array[Any]]) + def arrayUnion( + array1: ArrayData, + array2: ArrayData, + et: DataType, + ordering: Ordering[Any]): ArrayData = { + if (ordering == null) { + new GenericArrayData(array1.toObjectArray(et).union(array2.toObjectArray(et)) + .distinct.asInstanceOf[Array[Any]]) + } else { + val length = math.min(array1.numElements().toLong + array2.numElements().toLong, + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) + val array = new Array[Any](length.toInt) + var hasNull = false + array1.foreach(et, (i, v) => { + array(i) = v + if (v == null) { + hasNull = true + } + }) + var pos = array1.numElements() + array2.foreach(et, (_, v) => { + var found = false + if (v == null) { + if (hasNull) { + found = true + } else { + hasNull = true + } + } else { + var j = 0 + while (!found && j < pos) { + val va = array(j) + if (va != null && ordering.equiv(va, v)) { + found = true + } + j = j + 1 + } + } + if (!found) { + if (pos > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to union arrays with $pos" + + s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + array(pos) = v + pos = pos + 1 + } + }) + new GenericArrayData(array.slice(0, pos)) + } } } @@ -3327,9 +3365,28 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { override def dataType: DataType = left.dataType + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType, + s"function $prettyName") + } else { + typeCheckResult + } + } + private def cn = left.dataType.asInstanceOf[ArrayType].containsNull || right.dataType.asInstanceOf[ArrayType].containsNull + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + def intEval(ary: ArrayData, hs2: OpenHashSet[Int]): OpenHashSet[Int] def longEval(ary: ArrayData, hs2: OpenHashSet[Long]): OpenHashSet[Long] def genericEval(ary: ArrayData, hs2: OpenHashSet[Any], et: DataType): OpenHashSet[Any] @@ -3371,7 +3428,8 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { } } else { if (typeId == ArraySetLike.kindUnion) { - ArraySetLike.arrayUnion(ary1, ary2, elementType) + ArraySetLike.arrayUnion(ary1, ary2, elementType, + if (elementTypeSupportEquals) null else ordering) } else { null } @@ -3441,7 +3499,9 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { "" } val et = ctx.addReferenceObj("elementTypeUtil", elementType) - s"${ev.value} = $arraySetUtils$$.MODULE$$.array$setOp($ary1, $ary2, $et);" + val order = if (elementTypeSupportEquals) "null" + else 
ctx.addReferenceObj("orderingUtil", ordering) + s"${ev.value} = $arraySetUtils$$.MODULE$$.array$setOp($ary1, $ary2, $et, $order);" } }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 1ef6ae9c662f4..07067d697e1e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1188,7 +1188,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a23 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false)) val a24 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, false)) - val a30 = Literal.create(Seq(null, null), ArrayType(NullType)) + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) val a31 = Literal.create(null, ArrayType(StringType)) checkEvaluation(ArrayUnion(a00, a01), UnsafeArrayData.fromPrimitiveArray(Array(4, 1, 3, 2))) @@ -1209,5 +1209,31 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayUnion(a30, a30), Seq(null)) checkEvaluation(ArrayUnion(a20, a31), null) checkEvaluation(ArrayUnion(a31, a20), null) + + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), null), ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + + checkEvaluation(ArrayUnion(b0, b1), + Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](2, 1), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b0, b2), + Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b2, b3), Seq(Array[Byte](1, 2), Array[Byte](4, 3), null)) + checkEvaluation(ArrayUnion(b3, b0), Seq(Array[Byte](1, 2), null, Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b3, b4), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b3, arrayWithBinaryNull), Seq(Array[Byte](1, 2), null)) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayUnion(aa0, aa1), + Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](5, 6), Seq[Int](2, 1))) } } From 8a27667344eb6983bb3ee6f572ac5ab4a4012ddc Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 19 May 2018 02:29:00 +0100 Subject: [PATCH 20/34] add test cases with duplication in an array --- .../expressions/collectionOperations.scala | 12 +++--------- .../CollectionExpressionsSuite.scala | 18 ++++++++++++------ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0728099a50159..de0f5bdfbd2d9 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3320,15 +3320,9 @@ object ArraySetLike { val length = math.min(array1.numElements().toLong + array2.numElements().toLong, ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) val array = new Array[Any](length.toInt) + var pos = 0 var hasNull = false - array1.foreach(et, (i, v) => { - array(i) = v - if (v == null) { - hasNull = true - } - }) - var pos = array1.numElements() - array2.foreach(et, (_, v) => { + Seq(array1, array2).foreach(_.foreach(et, (_, v) => { var found = false if (v == null) { if (hasNull) { @@ -3354,7 +3348,7 @@ object ArraySetLike { array(pos) = v pos = pos + 1 } - }) + })) new GenericArrayData(array.slice(0, pos)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 07067d697e1e7..d8aacc63c87bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1216,18 +1216,24 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayType(BinaryType)) val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](4, 3)), ArrayType(BinaryType)) - val b3 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), null), ArrayType(BinaryType)) - val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]]( + Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](1, 2)), ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), ArrayType(BinaryType)) + val b6 = Literal.create(Seq.empty, ArrayType(BinaryType)) val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) checkEvaluation(ArrayUnion(b0, b1), Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](2, 1), Array[Byte](4, 3))) checkEvaluation(ArrayUnion(b0, b2), Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](4, 3))) - checkEvaluation(ArrayUnion(b2, b3), Seq(Array[Byte](1, 2), Array[Byte](4, 3), null)) - checkEvaluation(ArrayUnion(b3, b0), Seq(Array[Byte](1, 2), null, Array[Byte](5, 6))) - checkEvaluation(ArrayUnion(b3, b4), Seq(Array[Byte](1, 2), null)) - checkEvaluation(ArrayUnion(b3, arrayWithBinaryNull), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b2, b4), Seq(Array[Byte](1, 2), Array[Byte](4, 3), null)) + checkEvaluation(ArrayUnion(b3, b0), + Seq(Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b4, b0), Seq(Array[Byte](1, 2), null, Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b4, b5), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b6, b4), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b4, arrayWithBinaryNull), Seq(Array[Byte](1, 2), null)) val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), ArrayType(ArrayType(IntegerType))) From e50bc558439ecb15db116bf12708502de6f5e2e6 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 1 Jun 2018 09:14:31 +0100 Subject: [PATCH 21/34] rebase with master --- 
.../spark/sql/catalyst/expressions/collectionOperations.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index de0f5bdfbd2d9..b336764178226 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.{ByteArray, UTF8String} From 7e3f2ef92f1b4d13e1100995fcff608c0a7b760e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 1 Jun 2018 10:16:05 +0100 Subject: [PATCH 22/34] address review comments --- .../expressions/collectionOperations.scala | 79 ++++++++----------- 1 file changed, 35 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b336764178226..6b79011663eab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3263,8 +3263,6 @@ case class ArrayDistinct(child: Expression) } object ArraySetLike { - val kindUnion = 1 - private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH def toArrayDataInt(hs: OpenHashSet[Int]): ArrayData = { @@ -3277,9 +3275,9 @@ object ArraySetLike { i += 1 } - val numBytes = 4L * array.length + val numBytes = IntegerType.defaultSize.toLong * array.length val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(array.length) + - org.apache.spark.unsafe.array.ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) // Since UnsafeArrayData.fromPrimitiveArray() uses long[], max elements * 8 bytes can be used if (unsafeArraySizeInBytes <= Integer.MAX_VALUE * 8) { UnsafeArrayData.fromPrimitiveArray(array) @@ -3298,9 +3296,9 @@ object ArraySetLike { i += 1 } - val numBytes = 8L * array.length + val numBytes = LongType.defaultSize.toLong * array.length val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(array.length) + - org.apache.spark.unsafe.array.ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) // Since UnsafeArrayData.fromPrimitiveArray() uses long[], max elements * 8 bytes can be used if (unsafeArraySizeInBytes <= Integer.MAX_VALUE * 8) { UnsafeArrayData.fromPrimitiveArray(array) @@ -3356,7 +3354,7 @@ object ArraySetLike { } abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { - def typeId: Int + def arraySetLikeOpName: String override def dataType: DataType = left.dataType @@ -3373,10 +3371,10 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { private def cn = left.dataType.asInstanceOf[ArrayType].containsNull || 
right.dataType.asInstanceOf[ArrayType].containsNull - @transient private lazy val ordering: Ordering[Any] = + @transient protected lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) - @transient private lazy val elementTypeSupportEquals = elementType match { + @transient protected lazy val elementTypeSupportEquals = elementType match { case BinaryType => false case _: AtomicType => true case _ => false @@ -3384,7 +3382,8 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { def intEval(ary: ArrayData, hs2: OpenHashSet[Int]): OpenHashSet[Int] def longEval(ary: ArrayData, hs2: OpenHashSet[Long]): OpenHashSet[Long] - def genericEval(ary: ArrayData, hs2: OpenHashSet[Any], et: DataType): OpenHashSet[Any] + def genericEval(ary: ArrayData, hs2: OpenHashSet[Any]): OpenHashSet[Any] + def genericEvalContainsNull(ary1: ArrayData, ary2: ArrayData): ArrayData def codeGen(ctx: CodegenContext, hs2: String, hs: String, len: String, getter: String, i: String, postFix: String, newOpenHashSet: String): String @@ -3396,38 +3395,33 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { elementType match { case IntegerType => // avoid boxing of primitive int array elements - val hs2 = new OpenHashSet[Int] + val hs = new OpenHashSet[Int] var i = 0 while (i < ary2.numElements()) { - hs2.add(ary2.getInt(i)) + hs.add(ary2.getInt(i)) i += 1 } - ArraySetLike.toArrayDataInt(intEval(ary1, hs2)) + ArraySetLike.toArrayDataInt(intEval(ary1, hs)) case LongType => // avoid boxing of primitive long array elements - val hs2 = new OpenHashSet[Long] + val hs = new OpenHashSet[Long] var i = 0 while (i < ary2.numElements()) { - hs2.add(ary2.getLong(i)) + hs.add(ary2.getLong(i)) i += 1 } - ArraySetLike.toArrayDataLong(longEval(ary1, hs2)) + ArraySetLike.toArrayDataLong(longEval(ary1, hs)) case _ => - val hs2 = new OpenHashSet[Any] + val hs = new OpenHashSet[Any] var i = 0 while (i < ary2.numElements()) { - hs2.add(ary2.get(i, elementType)) + hs.add(ary2.get(i, elementType)) i += 1 } - new GenericArrayData(genericEval(ary1, hs2, elementType).iterator.toArray) + new GenericArrayData(genericEval(ary1, hs).iterator.toArray) } } else { - if (typeId == ArraySetLike.kindUnion) { - ArraySetLike.arrayUnion(ary1, ary2, elementType, - if (elementTypeSupportEquals) null else ordering) - } else { - null - } + genericEvalContainsNull(ary1, ary2) } } @@ -3488,15 +3482,10 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { |} """.stripMargin } else { - val setOp = if (typeId == ArraySetLike.kindUnion) { - "Union" - } else { - "" - } val et = ctx.addReferenceObj("elementTypeUtil", elementType) val order = if (elementTypeSupportEquals) "null" else ctx.addReferenceObj("orderingUtil", ordering) - s"${ev.value} = $arraySetUtils$$.MODULE$$.array$setOp($ary1, $ary2, $et, $order);" + s"${ev.value} = $arraySetUtils$$.MODULE$$.$arraySetLikeOpName($ary1, $ary2, $et, $order);" } }) } @@ -3517,36 +3506,38 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { """, since = "2.4.0") case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { - override def typeId: Int = ArraySetLike.kindUnion + override def arraySetLikeOpName: String = "arrayUnion" - override def intEval(ary: ArrayData, hs2: OpenHashSet[Int]): OpenHashSet[Int] = { + override def intEval(ary: ArrayData, hs: OpenHashSet[Int]): OpenHashSet[Int] = { var i = 0 while (i < ary.numElements()) { - hs2.add(ary.getInt(i)) + hs.add(ary.getInt(i)) i += 1 } - 
hs2 + hs } - override def longEval(ary: ArrayData, hs2: OpenHashSet[Long]): OpenHashSet[Long] = { + override def longEval(ary: ArrayData, hs: OpenHashSet[Long]): OpenHashSet[Long] = { var i = 0 while (i < ary.numElements()) { - hs2.add(ary.getLong(i)) + hs.add(ary.getLong(i)) i += 1 } - hs2 + hs } - override def genericEval( - ary: ArrayData, - hs2: OpenHashSet[Any], - et: DataType): OpenHashSet[Any] = { + override def genericEval(ary: ArrayData, hs: OpenHashSet[Any]): OpenHashSet[Any] = { var i = 0 while (i < ary.numElements()) { - hs2.add(ary.get(i, et)) + hs.add(ary.get(i, elementType)) i += 1 } - hs2 + hs + } + + override def genericEvalContainsNull(ary1: ArrayData, ary2: ArrayData): ArrayData = { + ArraySetLike.arrayUnion(ary1, ary2, elementType, + if (elementTypeSupportEquals) null else ordering) } override def codeGen( From e5401e7029b9aba192c0e5a69607a57cc41692d8 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 1 Jun 2018 20:35:16 +0100 Subject: [PATCH 23/34] address review comment --- .../expressions/collectionOperations.scala | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6b79011663eab..be2ee440ecc6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3275,14 +3275,10 @@ object ArraySetLike { i += 1 } - val numBytes = IntegerType.defaultSize.toLong * array.length - val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(array.length) + - ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) - // Since UnsafeArrayData.fromPrimitiveArray() uses long[], max elements * 8 bytes can be used - if (unsafeArraySizeInBytes <= Integer.MAX_VALUE * 8) { - UnsafeArrayData.fromPrimitiveArray(array) - } else { + if (useGenericArrayData(LongType.defaultSize, array.length)) { new GenericArrayData(array) + } else { + UnsafeArrayData.fromPrimitiveArray(array) } } @@ -3296,17 +3292,21 @@ object ArraySetLike { i += 1 } - val numBytes = LongType.defaultSize.toLong * array.length - val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(array.length) + - ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) - // Since UnsafeArrayData.fromPrimitiveArray() uses long[], max elements * 8 bytes can be used - if (unsafeArraySizeInBytes <= Integer.MAX_VALUE * 8) { - UnsafeArrayData.fromPrimitiveArray(array) - } else { + if (useGenericArrayData(LongType.defaultSize, array.length)) { new GenericArrayData(array) + } else { + UnsafeArrayData.fromPrimitiveArray(array) } } + def useGenericArrayData(elementSize: Int, length: Int): Boolean = { + // Use the same calculation in UnsafeArrayData.fromPrimitiveArray() + val headerInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(length) + val valueRegionInBytes = elementSize.toLong * length + val totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8 + totalSizeInLongs > Integer.MAX_VALUE / 8 + } + def arrayUnion( array1: ArrayData, array2: ArrayData, From 3e21e48d26a9b169163fe1aab66b0a4d1dae120d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 10 Jun 2018 09:09:21 +0100 Subject: [PATCH 24/34] keep the order of input array elements --- python/pyspark/sql/functions.py | 2 +- 
.../expressions/collectionOperations.scala    | 359 +++++++++---------
 .../CollectionExpressionsSuite.scala          |   6 +-
 .../org/apache/spark/sql/functions.scala      |   1 -
 .../spark/sql/DataFrameFunctionsSuite.scala   |   4 +-
 5 files changed, 181 insertions(+), 191 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index a25d75f5776b4..00cd416a8fff9 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2018,7 +2018,7 @@ def array_union(col1, col2):
 def array_union(col1, col2):
     """
     Collection function: returns an array of the elements in the union of col1 and col2,
-    without duplicates. The order of elements in the result is not determined.
+    without duplicates.
 
     :param col1: name of column containing array
     :param col2: name of column containing array
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index be2ee440ecc6e..b593630b65e92 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3263,42 +3263,6 @@ case class ArrayDistinct(child: Expression)
 }
 
 object ArraySetLike {
-  private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
-
-  def toArrayDataInt(hs: OpenHashSet[Int]): ArrayData = {
-    val array = new Array[Int](hs.size)
-    var pos = hs.nextPos(0)
-    var i = 0
-    while (pos != OpenHashSet.INVALID_POS) {
-      array(i) = hs.getValue(pos)
-      pos = hs.nextPos(pos + 1)
-      i += 1
-    }
-
-    if (useGenericArrayData(LongType.defaultSize, array.length)) {
-      new GenericArrayData(array)
-    } else {
-      UnsafeArrayData.fromPrimitiveArray(array)
-    }
-  }
-
-  def toArrayDataLong(hs: OpenHashSet[Long]): ArrayData = {
-    val array = new Array[Long](hs.size)
-    var pos = hs.nextPos(0)
-    var i = 0
-    while (pos != OpenHashSet.INVALID_POS) {
-      array(i) = hs.getValue(pos)
-      pos = hs.nextPos(pos + 1)
-      i += 1
-    }
-
-    if (useGenericArrayData(LongType.defaultSize, array.length)) {
-      new GenericArrayData(array)
-    } else {
-      UnsafeArrayData.fromPrimitiveArray(array)
-    }
-  }
-
   def useGenericArrayData(elementSize: Int, length: Int): Boolean = {
     // Use the same calculation in UnsafeArrayData.fromPrimitiveArray()
     val headerInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(length)
@@ -3307,55 +3271,70 @@ object ArraySetLike {
     totalSizeInLongs > Integer.MAX_VALUE / 8
   }
 
-  def arrayUnion(
+  def throwUnionLengthOverflowException(length: Int): Unit = {
+    throw new RuntimeException(s"Unsuccessful try to union arrays with $length " +
+      s"elements due to exceeding the array size limit " +
+      s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
+  }
+
+  def evalUnionContainsNull(
       array1: ArrayData,
       array2: ArrayData,
-      et: DataType,
+      elementType: DataType,
       ordering: Ordering[Any]): ArrayData = {
     if (ordering == null) {
-      new GenericArrayData(array1.toObjectArray(et).union(array2.toObjectArray(et))
-        .distinct.asInstanceOf[Array[Any]])
+      val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+      val hs = new mutable.HashSet[Any]
+      Seq(array1, array2).foreach(array => {
+        var i = 0
+        while (i < array.numElements()) {
+          val elem = array.get(i, elementType)
+          if (hs.add(elem)) {
+            if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+              throwUnionLengthOverflowException(arrayBuffer.length)
+            }
+            arrayBuffer += elem
+          }
+          i += 1
+        }
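+        // NOTE: hs.add returns true only for the first occurrence of an
+        // element, so arrayBuffer keeps one copy of each distinct element
+        // (including at most one null) in order of first appearance.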
+ }) + new GenericArrayData(arrayBuffer) } else { - val length = math.min(array1.numElements().toLong + array2.numElements().toLong, - ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) - val array = new Array[Any](length.toInt) - var pos = 0 - var hasNull = false - Seq(array1, array2).foreach(_.foreach(et, (_, v) => { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { var found = false - if (v == null) { - if (hasNull) { + if (elem == null) { + if (alreadyIncludeNull) { found = true } else { - hasNull = true + alreadyIncludeNull = true } } else { + // check elem is already stored in arrayBuffer or not? var j = 0 - while (!found && j < pos) { - val va = array(j) - if (va != null && ordering.equiv(va, v)) { + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { found = true } j = j + 1 } } if (!found) { - if (pos > MAX_ARRAY_LENGTH) { - throw new RuntimeException(s"Unsuccessful try to union arrays with $pos" + - s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throwUnionLengthOverflowException(arrayBuffer.length) } - array(pos) = v - pos = pos + 1 + arrayBuffer += elem } })) - new GenericArrayData(array.slice(0, pos)) + new GenericArrayData(arrayBuffer) } } } -abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { - def arraySetLikeOpName: String +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { override def dataType: DataType = left.dataType override def checkInputDataTypes(): TypeCheckResult = { @@ -3368,7 +3347,7 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { } } - private def cn = left.dataType.asInstanceOf[ArrayType].containsNull || + protected def cn = left.dataType.asInstanceOf[ArrayType].containsNull || right.dataType.asInstanceOf[ArrayType].containsNull @transient protected lazy val ordering: Ordering[Any] = @@ -3379,184 +3358,196 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { case _: AtomicType => true case _ => false } +} - def intEval(ary: ArrayData, hs2: OpenHashSet[Int]): OpenHashSet[Int] - def longEval(ary: ArrayData, hs2: OpenHashSet[Long]): OpenHashSet[Long] - def genericEval(ary: ArrayData, hs2: OpenHashSet[Any]): OpenHashSet[Any] - def genericEvalContainsNull(ary1: ArrayData, ary2: ArrayData): ArrayData - def codeGen(ctx: CodegenContext, hs2: String, hs: String, len: String, getter: String, i: String, - postFix: String, newOpenHashSet: String): String +/** + * Returns an array of the elements in the union of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, + without duplicates. 
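+      The order of the elements in the result is the order of their first
+      appearance in array1 followed by array2.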
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { override def nullSafeEval(input1: Any, input2: Any): Any = { - val ary1 = input1.asInstanceOf[ArrayData] - val ary2 = input2.asInstanceOf[ArrayData] + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] if (!cn) { elementType match { case IntegerType => // avoid boxing of primitive int array elements + // calculate result array size + val hsSize = new OpenHashSet[Int] + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(hsSize.size) + } + hsSize.add(array.getInt(i)) + i += 1 + } + }) + // store elements into array + val resultArray = new Array[Int](hsSize.size) val hs = new OpenHashSet[Int] - var i = 0 - while (i < ary2.numElements()) { - hs.add(ary2.getInt(i)) - i += 1 + var pos = 0 + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements () ) { + val elem = array.getInt (i) + if (!hs.contains (elem) ) { + resultArray (pos) = elem + hs.add (elem) + pos += 1 + } + i += 1 + } + }) + if (ArraySetLike.useGenericArrayData(IntegerType.defaultSize, resultArray.length)) { + new GenericArrayData(resultArray) + } else { + UnsafeArrayData.fromPrimitiveArray(resultArray) } - ArraySetLike.toArrayDataInt(intEval(ary1, hs)) case LongType => // avoid boxing of primitive long array elements + // calculate result array size + val hsSize = new OpenHashSet[Long] + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(hsSize.size) + } + hsSize.add(array.getLong(i)) + i += 1 + } + }) + // store elements into array + val resultArray = new Array[Long](hsSize.size) val hs = new OpenHashSet[Long] - var i = 0 - while (i < ary2.numElements()) { - hs.add(ary2.getLong(i)) - i += 1 + var pos = 0 + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + val elem = array.getLong(i) + if (!hs.contains(elem)) { + resultArray(pos) = elem + hs.add(elem) + pos += 1 + } + i += 1 + } + }) + if (ArraySetLike.useGenericArrayData(LongType.defaultSize, resultArray.length)) { + new GenericArrayData(resultArray) + } else { + UnsafeArrayData.fromPrimitiveArray(resultArray) } - ArraySetLike.toArrayDataLong(longEval(ary1, hs)) case _ => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] val hs = new OpenHashSet[Any] - var i = 0 - while (i < ary2.numElements()) { - hs.add(ary2.get(i, elementType)) - i += 1 - } - new GenericArrayData(genericEval(ary1, hs).iterator.toArray) + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + val elem = array.get(i, elementType) + if (!hs.contains(elem)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += elem + hs.add(elem) + } + i += 1 + } + }) + new GenericArrayData(arrayBuffer) } } else { - genericEvalContainsNull(ary1, ary2) + ArraySetLike.evalUnionContainsNull(array1, array2, elementType, + if (elementTypeSupportEquals) null else ordering) } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val i = ctx.freshName("i") + 
val pos = ctx.freshName("pos") val value = ctx.freshName("value") val size = ctx.freshName("size") - val ary = ctx.freshName("ary") - val arraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetLike" val genericArrayData = classOf[GenericArrayData].getName - val openHashSet = classOf[OpenHashSet[_]].getName val (postFix, classTag, getter, setter, javaTypeName, arrayBuilder) = if (!cn) { val ptName = CodeGenerator.primitiveTypeName(elementType) elementType match { case ByteType | ShortType | IntegerType | LongType => - val uary = ctx.freshName("uary") + val unsafeArray = ctx.freshName("unsafeArray") (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", - s"get$ptName($i)", s"set$ptName($i, $value)", CodeGenerator.javaType(elementType), + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), s""" - |${ctx.createUnsafeArray(uary, size, elementType, s" $prettyName failed.")} - |${ev.value} = $uary; + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; """.stripMargin) case _ => val et = ctx.addReferenceObj("elementType", elementType) ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", - s"get($i, $et)", s"update($i, $value)", "Object", - s"${ev.value} = new $genericArrayData(new Object[$size]);") + s"get($i, $et)", s"update($pos, $value)", "Object", + s"${ev.value} = new $genericArrayData(new Object[$size]);") } } else { ("", "", "", "", "", "") } val hs = ctx.freshName("hs") - val hs2 = ctx.freshName("hs2") - val invalidPos = ctx.freshName("invalidPos") - val pos = ctx.freshName("pos") - nullSafeCodeGen(ctx, ev, (ary1, ary2) => { + nullSafeCodeGen(ctx, ev, (array1, array2) => { if (classTag != "") { - val secondLoop = codeGen(ctx, hs2, hs, s"$ary1.numElements()", s"$ary1.$getter", i, - postFix, s"new $openHashSet$postFix($classTag)") + val openHashSet = classOf[OpenHashSet[_]].getName s""" - |$openHashSet $hs2 = new $openHashSet$postFix($classTag); - |for (int $i = 0; $i < $ary2.numElements(); $i++) { - | $hs2.add$postFix($ary2.$getter); + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | $hs.add$postFix($array1.$getter); + |} + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | $hs.add$postFix($array2.$getter); |} - |$secondLoop |int $size = $hs.size(); |$arrayBuilder - |int $invalidPos = $openHashSet.INVALID_POS(); - |int $pos = $hs.nextPos(0); - |int $i = 0; - |while ($pos != $invalidPos) { - | $javaTypeName $value = ($javaTypeName) $hs.getValue$postFix($pos); - | ${ev.value}.$setter; - | $pos = $hs.nextPos($pos + 1); - | $i++; + |$hs = new $openHashSet$postFix($classTag); + |int $pos = 0; + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | $javaTypeName $value = $array1.$getter; + | if (!$hs.contains($value)) { + | $hs.add$postFix($value); + | ${ev.value}.$setter; + | $pos++; + | } + |} + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | $javaTypeName $value = $array2.$getter; + | if (!$hs.contains($value)) { + | $hs.add$postFix($value); + | ${ev.value}.$setter; + | $pos++; + | } |} """.stripMargin } else { - val et = ctx.addReferenceObj("elementTypeUtil", elementType) + val arraySetLike = "org.apache.spark.sql.catalyst.expressions.ArraySetLike" + val et = ctx.addReferenceObj("elementTypeUnion", elementType) val order = if (elementTypeSupportEquals) "null" - else ctx.addReferenceObj("orderingUtil", ordering) - s"${ev.value} = 
$arraySetUtils$$.MODULE$$.$arraySetLikeOpName($ary1, $ary2, $et, $order);" + else ctx.addReferenceObj("orderingUnion", ordering) + val method = "evalUnionContainsNull" + s"${ev.value} = $arraySetLike$$.MODULE$$.$method($array1, $array2, $et, $order);" } }) } -} - -/** - * Returns an array of the elements in the union of x and y, without duplicates - */ -@ExpressionDescription( - usage = """ - _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, - without duplicates. The order of elements in the result is not determined. - """, - examples = """ - Examples: - > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); - array(1, 2, 3, 5) - """, - since = "2.4.0") -case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { - override def arraySetLikeOpName: String = "arrayUnion" - - override def intEval(ary: ArrayData, hs: OpenHashSet[Int]): OpenHashSet[Int] = { - var i = 0 - while (i < ary.numElements()) { - hs.add(ary.getInt(i)) - i += 1 - } - hs - } - - override def longEval(ary: ArrayData, hs: OpenHashSet[Long]): OpenHashSet[Long] = { - var i = 0 - while (i < ary.numElements()) { - hs.add(ary.getLong(i)) - i += 1 - } - hs - } - - override def genericEval(ary: ArrayData, hs: OpenHashSet[Any]): OpenHashSet[Any] = { - var i = 0 - while (i < ary.numElements()) { - hs.add(ary.get(i, elementType)) - i += 1 - } - hs - } - - override def genericEvalContainsNull(ary1: ArrayData, ary2: ArrayData): ArrayData = { - ArraySetLike.arrayUnion(ary1, ary2, elementType, - if (elementTypeSupportEquals) null else ordering) - } - - override def codeGen( - ctx: CodegenContext, - hs2: String, - hs: String, - len: String, - getter: String, - i: String, - postFix: String, - newOpenHashSet: String): String = { - val openHashSet = classOf[OpenHashSet[_]].getName - s""" - |for (int $i = 0; $i < $len; $i++) { - | $hs2.add$postFix($getter); - |} - |$openHashSet $hs = $hs2; - """.stripMargin - } override def prettyName: String = "array_union" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d8aacc63c87bb..c94f606c35652 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1191,20 +1191,20 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) val a31 = Literal.create(null, ArrayType(StringType)) - checkEvaluation(ArrayUnion(a00, a01), UnsafeArrayData.fromPrimitiveArray(Array(4, 1, 3, 2))) + checkEvaluation(ArrayUnion(a00, a01), UnsafeArrayData.fromPrimitiveArray(Array(1, 2, 3, 4))) checkEvaluation(ArrayUnion(a01, a02), Seq(4, 2, 1, 3)) checkEvaluation(ArrayUnion(a03, a04), Seq(1, 2, null, 4, 5, -5, -3, -1)) checkEvaluation(ArrayUnion(a03, a05), Seq(1, 2, null, 4, 5)) checkEvaluation( - ArrayUnion(a10, a11), UnsafeArrayData.fromPrimitiveArray(Array(4L, 1L, 3L, 2L))) + ArrayUnion(a10, a11), UnsafeArrayData.fromPrimitiveArray(Array(1L, 2L, 3L, 4L))) checkEvaluation(ArrayUnion(a11, a12), Seq(4L, 2L, 1L, 3L)) checkEvaluation(ArrayUnion(a13, a14), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) checkEvaluation(ArrayUnion(a13, a15), Seq(1L, 2L, null, 4L, 5L)) checkEvaluation(ArrayUnion(a20, a21), Seq("b", "a", "c", "d", "f")) 
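    // The expected values above changed from hash-set iteration order (the
    // old code materialized an OpenHashSet, e.g. Array(4, 1, 3, 2)) to
    // first-occurrence order: the union now scans array1, then array2, and
    // emits each element the first time it is seen. The same semantics in
    // plain Scala, since Seq#distinct keeps the first occurrence:
    //   (Seq("b", "a", "c") ++ Seq("c", "d", "a", "f")).distinct
    //   // == Seq("b", "a", "c", "d", "f")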
checkEvaluation(ArrayUnion(a20, a22), Seq("b", "a", "c", null, "g")) - checkEvaluation(ArrayUnion(a23, a24), Seq("b", "c", "d", "a", "f")) + checkEvaluation(ArrayUnion(a23, a24), Seq("b", "a", "c", "d", "f")) checkEvaluation(ArrayUnion(a30, a30), Seq(null)) checkEvaluation(ArrayUnion(a20, a31), null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 33223f0fe8cc0..9470364736b13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3199,7 +3199,6 @@ object functions { /** * Returns an array of the elements in the union of the given two arrays, without duplicates. - * The order of elements in the result is not determined * * @group collection_funcs * @since 2.4.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 10180d75e9505..8cd9aad79b957 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1122,7 +1122,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array_union functions") { val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b") - val ans1 = Row(Seq(4, 1, 3, 2)) + val ans1 = Row(Seq(1, 2, 3, 4)) checkAnswer(df1.select(array_union($"a", $"b")), ans1) checkAnswer(df1.selectExpr("array_union(a, b)"), ans1) @@ -1132,7 +1132,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df2.selectExpr("array_union(a, b)"), ans2) val df3 = Seq((Array(1L, 2L, 3L), Array(4L, 2L))).toDF("a", "b") - val ans3 = Row(Seq(4L, 1L, 3L, 2L)) + val ans3 = Row(Seq(1L, 2L, 3L, 4L)) checkAnswer(df3.select(array_union($"a", $"b")), ans3) checkAnswer(df3.selectExpr("array_union(a, b)"), ans3) From 3c395063eab65a7dac862b9ac54483e445fd149c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 20 Jun 2018 02:03:24 +0100 Subject: [PATCH 25/34] address review comments --- .../expressions/collectionOperations.scala | 176 +++++++++--------- .../CollectionExpressionsSuite.scala | 4 +- 2 files changed, 93 insertions(+), 87 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b593630b65e92..a97179bfc665c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3272,65 +3272,10 @@ object ArraySetLike { } def throwUnionLengthOverflowException(length: Int): Unit = { - throw new RuntimeException(s"Unsuccessful try to union arrays with ${length}" + + throw new RuntimeException(s"Unsuccessful try to union arrays with $length " + s"elements due to exceeding the array size limit " + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") } - - def evalUnionContainsNull( - array1: ArrayData, - array2: ArrayData, - elementType: DataType, - ordering: Ordering[Any]): ArrayData = { - if (ordering == null) { - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - val hs = new mutable.HashSet[Any] - Seq(array1, array2).foreach(array => { - var i = 0 - while (i < array.numElements()) { - val elem = array.get(i, 
elementType) - if (hs.add(elem)) { - if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throwUnionLengthOverflowException(arrayBuffer.length) - } - arrayBuffer += elem - } - i += 1 - } - }) - new GenericArrayData(arrayBuffer) - } else { - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - var alreadyIncludeNull = false - Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { - var found = false - if (elem == null) { - if (alreadyIncludeNull) { - found = true - } else { - alreadyIncludeNull = true - } - } else { - // check elem is already stored in arrayBuffer or not? - var j = 0 - while (!found && j < arrayBuffer.size) { - val va = arrayBuffer(j) - if (va != null && ordering.equiv(va, elem)) { - found = true - } - j = j + 1 - } - } - if (!found) { - if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throwUnionLengthOverflowException(arrayBuffer.length) - } - arrayBuffer += elem - } - })) - new GenericArrayData(arrayBuffer) - } - } } @@ -3380,7 +3325,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike val array1 = input1.asInstanceOf[ArrayData] val array2 = input2.asInstanceOf[ArrayData] - if (!cn) { + if (elementTypeSupportEquals && !cn) { elementType match { case IntegerType => // avoid boxing of primitive int array elements @@ -3396,17 +3341,17 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike i += 1 } }) - // store elements into array + // store elements into resultArray val resultArray = new Array[Int](hsSize.size) val hs = new OpenHashSet[Int] var pos = 0 Seq(array1, array2).foreach(array => { var i = 0 - while (i < array.numElements () ) { - val elem = array.getInt (i) - if (!hs.contains (elem) ) { - resultArray (pos) = elem - hs.add (elem) + while (i < array.numElements()) { + val elem = array.getInt(i) + if (!hs.contains(elem)) { + resultArray(pos) = elem + hs.add(elem) pos += 1 } i += 1 @@ -3431,7 +3376,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike i += 1 } }) - // store elements into array + // store elements into resultArray val resultArray = new Array[Long](hsSize.size) val hs = new OpenHashSet[Long] var pos = 0 @@ -3472,7 +3417,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike new GenericArrayData(arrayBuffer) } } else { - ArraySetLike.evalUnionContainsNull(array1, array2, elementType, + ArrayUnion.evalUnionContainsNull(array1, array2, elementType, if (elementTypeSupportEquals) null else ordering) } } @@ -3483,31 +3428,33 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike val value = ctx.freshName("value") val size = ctx.freshName("size") val genericArrayData = classOf[GenericArrayData].getName - val (postFix, classTag, getter, setter, javaTypeName, arrayBuilder) = if (!cn) { - val ptName = CodeGenerator.primitiveTypeName(elementType) - elementType match { - case ByteType | ShortType | IntegerType | LongType => - val unsafeArray = ctx.freshName("unsafeArray") - (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", - s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", - s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), - s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; + val (postFix, classTag, getter, setter, javaTypeName, arrayBuilder) = + if (elementTypeSupportEquals && !cn) { + elementType match { + case ByteType | ShortType | 
IntegerType | LongType => + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", + s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; """.stripMargin) - case _ => - val et = ctx.addReferenceObj("elementType", elementType) - ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", - s"get($i, $et)", s"update($pos, $value)", "Object", - s"${ev.value} = new $genericArrayData(new Object[$size]);") + case _ => + val et = ctx.addReferenceObj("elementType", elementType) + ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", + s"get($i, $et)", s"update($pos, $value)", "Object", + s"${ev.value} = new $genericArrayData(new Object[$size]);") + } + } else { + ("", "", "", "", "", "") } - } else { - ("", "", "", "", "", "") - } val hs = ctx.freshName("hs") nullSafeCodeGen(ctx, ev, (array1, array2) => { if (classTag != "") { + // Here, we ensure elementTypeSupportEquals && !array1.containsNull && !array2.containsNull val openHashSet = classOf[OpenHashSet[_]].getName s""" |$openHashSet $hs = new $openHashSet$postFix($classTag); @@ -3539,15 +3486,72 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike |} """.stripMargin } else { - val arraySetLike = "org.apache.spark.sql.catalyst.expressions.ArraySetLike" + val arrayUnion = classOf[ArrayUnion].getName val et = ctx.addReferenceObj("elementTypeUnion", elementType) val order = if (elementTypeSupportEquals) "null" else ctx.addReferenceObj("orderingUnion", ordering) val method = "evalUnionContainsNull" - s"${ev.value} = $arraySetLike$$.MODULE$$.$method($array1, $array2, $et, $order);" + s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);" } }) } override def prettyName: String = "array_union" } + +object ArrayUnion { + def evalUnionContainsNull( + array1: ArrayData, + array2: ArrayData, + elementType: DataType, + ordering: Ordering[Any]): ArrayData = { + if (ordering == null) { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new mutable.HashSet[Any] + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + val elem = array.get(i, elementType) + if (hs.add(elem)) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem + } + i += 1 + } + }) + new GenericArrayData(arrayBuffer) + } else { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true + } + } else { + // check elem is already stored in arrayBuffer or not? 
+ var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { + found = true + } + j = j + 1 + } + } + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem + } + })) + new GenericArrayData(arrayBuffer) + } + } +} \ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index c94f606c35652..c32c387aa1a16 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1194,6 +1194,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayUnion(a00, a01), UnsafeArrayData.fromPrimitiveArray(Array(1, 2, 3, 4))) checkEvaluation(ArrayUnion(a01, a02), Seq(4, 2, 1, 3)) checkEvaluation(ArrayUnion(a03, a04), Seq(1, 2, null, 4, 5, -5, -3, -1)) + checkEvaluation(ArrayUnion(a04, a03), Seq(-5, 4, -3, 2, -1, 1, null, 5)) checkEvaluation(ArrayUnion(a03, a05), Seq(1, 2, null, 4, 5)) checkEvaluation( @@ -1213,7 +1214,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), ArrayType(BinaryType)) val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), - ArrayType(BinaryType)) + ArrayType(BinaryType, false)) val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](4, 3)), ArrayType(BinaryType)) val b3 = Literal.create(Seq[Array[Byte]]( @@ -1227,6 +1228,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](2, 1), Array[Byte](4, 3))) checkEvaluation(ArrayUnion(b0, b2), Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b1, b1), Seq(Array[Byte](2, 1), Array[Byte](4, 3))) checkEvaluation(ArrayUnion(b2, b4), Seq(Array[Byte](1, 2), Array[Byte](4, 3), null)) checkEvaluation(ArrayUnion(b3, b0), Seq(Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](5, 6))) From 665474209991c093b6a08a693647e469d9f363c6 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 20 Jun 2018 03:30:39 +0100 Subject: [PATCH 26/34] fix scala style error --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a97179bfc665c..ac21ce12500dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3554,4 +3554,4 @@ object ArrayUnion { new GenericArrayData(arrayBuffer) } } -} \ No newline at end of file +} From be9f33149492f22320517448b185f07b6504a691 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 20 Jun 2018 11:28:23 +0100 Subject: [PATCH 27/34] address review comment refactoring to reduce # of lines --- 
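Note on the shape of this refactoring: for Int/Long elements the union is
still built in two passes (a sizing pass over an OpenHashSet, then a fill
pass into a preallocated primitive array), which avoids both boxing and
buffer resizing; this patch folds the duplicated per-type loops into the
shared assignInt/assignLong/evalPrimitiveType helpers and threads null
handling (at most one null in the result) through the primitive paths. The
observable semantics, deduplication in first-occurrence order, match this
minimal plain-Scala sketch; unionInts and the use of LinkedHashSet are
illustrative stand-ins rather than Spark internals:

  import scala.collection.mutable

  def unionInts(a: Array[Int], b: Array[Int]): Array[Int] = {
    // LinkedHashSet keeps insertion order, so the result lists each
    // distinct value in the order it is first seen across a, then b
    val seen = mutable.LinkedHashSet.empty[Int]
    for (arr <- Seq(a, b); e <- arr) seen += e
    seen.toArray
  }

  // unionInts(Array(1, 2, 3), Array(4, 2)) sameElements Array(1, 2, 3, 4)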
.../catalyst/expressions/UnsafeArrayData.java | 7 + .../expressions/collectionOperations.scala | 196 +++++++++++------- .../CollectionExpressionsSuite.scala | 43 ++-- 3 files changed, 148 insertions(+), 98 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 4dd2b7365652a..1a69757cb3396 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -471,6 +471,13 @@ private static UnsafeArrayData fromPrimitiveArray( return result; } + public static boolean useGenericArrayData(int elementSize, int length) { + final long headerInBytes = calculateHeaderPortionInBytes(length); + final long valueRegionInBytes = (long)elementSize * length; + final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; + return totalSizeInLongs > Integer.MAX_VALUE / 8; + } + public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) { return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ac21ce12500dc..afe8389440ef3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3292,9 +3292,6 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { } } - protected def cn = left.dataType.asInstanceOf[ArrayType].containsNull || - right.dataType.asInstanceOf[ArrayType].containsNull - @transient protected lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) @@ -3320,12 +3317,70 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { """, since = "2.4.0") case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { + var hsInt: OpenHashSet[Int] = _ + var hsLong: OpenHashSet[Long] = _ + + def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getInt(idx) + if (!hsInt.contains(elem)) { + resultArray.setInt(pos, elem) + hsInt.add(elem) + true + } else { + false + } + } + + def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getLong(idx) + if (!hsLong.contains(elem)) { + resultArray.setLong(pos, elem) + hsLong.add(elem) + true + } else { + false + } + } + + def evalPrimitiveType( + array1: ArrayData, + array2: ArrayData, + size: Int, + resultArray: ArrayData, + isLongType: Boolean): ArrayData = { + // store elements into resultArray + var foundNullElement = false + var pos = 0 + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + if (!foundNullElement) { + resultArray.setNullAt(pos) + pos += 1 + foundNullElement = true + } + } else { + val assigned = if (!isLongType) { + assignInt(array, i, resultArray, pos) + } else { + assignLong(array, i, resultArray, pos) + } + if (assigned) { + pos += 1 + } + } + i += 1 + } + }) + resultArray + } override def nullSafeEval(input1: Any, input2: Any): Any = { val array1 = input1.asInstanceOf[ArrayData] val array2 = 
input2.asInstanceOf[ArrayData] - if (elementTypeSupportEquals && !cn) { + if (elementTypeSupportEquals) { elementType match { case IntegerType => // avoid boxing of primitive int array elements @@ -3341,26 +3396,13 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike i += 1 } }) - // store elements into resultArray - val resultArray = new Array[Int](hsSize.size) - val hs = new OpenHashSet[Int] - var pos = 0 - Seq(array1, array2).foreach(array => { - var i = 0 - while (i < array.numElements()) { - val elem = array.getInt(i) - if (!hs.contains(elem)) { - resultArray(pos) = elem - hs.add(elem) - pos += 1 - } - i += 1 - } - }) - if (ArraySetLike.useGenericArrayData(IntegerType.defaultSize, resultArray.length)) { - new GenericArrayData(resultArray) + if (UnsafeArrayData.useGenericArrayData(IntegerType.defaultSize, hsSize.size)) { + ArrayUnion.evalUnionContainsNull(array1, array2, elementType, + if (elementTypeSupportEquals) null else ordering) } else { - UnsafeArrayData.fromPrimitiveArray(resultArray) + hsInt = new OpenHashSet[Int] + val resultArray = UnsafeArrayData.fromPrimitiveArray(new Array[Int](hsSize.size)) + evalPrimitiveType(array1, array2, hsSize.size, resultArray, false) } case LongType => // avoid boxing of primitive long array elements @@ -3376,40 +3418,35 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike i += 1 } }) - // store elements into resultArray - val resultArray = new Array[Long](hsSize.size) - val hs = new OpenHashSet[Long] - var pos = 0 - Seq(array1, array2).foreach(array => { - var i = 0 - while (i < array.numElements()) { - val elem = array.getLong(i) - if (!hs.contains(elem)) { - resultArray(pos) = elem - hs.add(elem) - pos += 1 - } - i += 1 - } - }) - if (ArraySetLike.useGenericArrayData(LongType.defaultSize, resultArray.length)) { - new GenericArrayData(resultArray) + if (UnsafeArrayData.useGenericArrayData(IntegerType.defaultSize, hsSize.size)) { + ArrayUnion.evalUnionContainsNull(array1, array2, elementType, + if (elementTypeSupportEquals) null else ordering) } else { - UnsafeArrayData.fromPrimitiveArray(resultArray) + hsLong = new OpenHashSet[Long] + val resultArray = UnsafeArrayData.fromPrimitiveArray(new Array[Long](hsSize.size)) + evalPrimitiveType(array1, array2, hsSize.size, resultArray, true) } case _ => val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] val hs = new OpenHashSet[Any] + var foundNullElement = false Seq(array1, array2).foreach(array => { var i = 0 while (i < array.numElements()) { - val elem = array.get(i, elementType) - if (!hs.contains(elem)) { - if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + if (array.isNullAt(i)) { + if (!foundNullElement) { + arrayBuffer += null + foundNullElement = true + } + } else { + val elem = array.get(i, elementType) + if (!hs.contains(elem)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += elem + hs.add(elem) } - arrayBuffer += elem - hs.add(elem) } i += 1 } @@ -3427,9 +3464,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike val pos = ctx.freshName("pos") val value = ctx.freshName("value") val size = ctx.freshName("size") - val genericArrayData = classOf[GenericArrayData].getName val (postFix, classTag, getter, setter, javaTypeName, arrayBuilder) = - if (elementTypeSupportEquals && !cn) { + if 
(elementTypeSupportEquals) { elementType match { case ByteType | ShortType | IntegerType | LongType => val ptName = CodeGenerator.primitiveTypeName(elementType) @@ -3442,6 +3478,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike |${ev.value} = $unsafeArray; """.stripMargin) case _ => + val genericArrayData = classOf[GenericArrayData].getName val et = ctx.addReferenceObj("elementType", elementType) ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $et)", s"update($pos, $value)", "Object", @@ -3451,37 +3488,51 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike ("", "", "", "", "", "") } - val hs = ctx.freshName("hs") nullSafeCodeGen(ctx, ev, (array1, array2) => { if (classTag != "") { // Here, we ensure elementTypeSupportEquals && !array1.containsNull && !array2.containsNull + val foundNullElement = ctx.freshName("foundNullElement") val openHashSet = classOf[OpenHashSet[_]].getName + val hs = ctx.freshName("hs") + val arrayData = classOf[ArrayData].getName + val arrays = ctx.freshName("arrays") + val array = ctx.freshName("array") + val arrayDataIdx = ctx.freshName("arrayDataIdx") s""" |$openHashSet $hs = new $openHashSet$postFix($classTag); - |for (int $i = 0; $i < $array1.numElements(); $i++) { - | $hs.add$postFix($array1.$getter); - |} - |for (int $i = 0; $i < $array2.numElements(); $i++) { - | $hs.add$postFix($array2.$getter); + |boolean $foundNullElement = false; + |$arrayData[] $arrays = new $arrayData[]{$array1, $array2}; + |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { + | $arrayData $array = $arrays[$arrayDataIdx]; + | for (int $i = 0; $i < $array.numElements(); $i++) { + | if ($array.isNullAt($i)) { + | $foundNullElement = true; + | } else { + | $hs.add$postFix($array.$getter); + | } + | } |} - |int $size = $hs.size(); + |int $size = $hs.size() + ($foundNullElement ? 
1 : 0); |$arrayBuilder |$hs = new $openHashSet$postFix($classTag); + |$foundNullElement = false; |int $pos = 0; - |for (int $i = 0; $i < $array1.numElements(); $i++) { - | $javaTypeName $value = $array1.$getter; - | if (!$hs.contains($value)) { - | $hs.add$postFix($value); - | ${ev.value}.$setter; - | $pos++; - | } - |} - |for (int $i = 0; $i < $array2.numElements(); $i++) { - | $javaTypeName $value = $array2.$getter; - | if (!$hs.contains($value)) { - | $hs.add$postFix($value); - | ${ev.value}.$setter; - | $pos++; + |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { + | $arrayData $array = $arrays[$arrayDataIdx]; + | for (int $i = 0; $i < $array.numElements(); $i++) { + | if ($array.isNullAt($i)) { + | if (!$foundNullElement) { + | ${ev.value}.setNullAt($pos++); + | $foundNullElement = true; + | } + | } else { + | $javaTypeName $value = $array.$getter; + | if (!$hs.contains($value)) { + | $hs.add$postFix($value); + | ${ev.value}.$setter; + | $pos++; + | } + | } | } |} """.stripMargin @@ -3505,8 +3556,8 @@ object ArrayUnion { array2: ArrayData, elementType: DataType, ordering: Ordering[Any]): ArrayData = { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] if (ordering == null) { - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] val hs = new mutable.HashSet[Any] Seq(array1, array2).foreach(array => { var i = 0 @@ -3523,7 +3574,6 @@ object ArrayUnion { }) new GenericArrayData(arrayBuffer) } else { - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] var alreadyIncludeNull = false Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { var found = false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index c32c387aa1a16..04269a6915e87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1168,44 +1168,38 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("Array Union") { - val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, false)) - val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) - val a02 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) - val a03 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType)) - val a04 = Literal.create(Seq(-5, 4, -3, 2, -1), ArrayType(IntegerType)) - val a05 = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) - - val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, false)) - val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) - val a12 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType)) - val a13 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType)) - val a14 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), ArrayType(LongType)) - val a15 = Literal.create(Seq.empty[Long], ArrayType(LongType)) + val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType)) + val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType)) + val a03 = Literal.create(Seq(-5, 4, -3, 2, -1), ArrayType(IntegerType)) + val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) + + val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType)) + val a11 = Literal.create(Seq(4L, 2L), 
ArrayType(LongType)) + val a12 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType)) + val a13 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), ArrayType(LongType)) + val a14 = Literal.create(Seq.empty[Long], ArrayType(LongType)) val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType)) val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType)) val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType)) - val a23 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false)) - val a24 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, false)) val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) val a31 = Literal.create(null, ArrayType(StringType)) checkEvaluation(ArrayUnion(a00, a01), UnsafeArrayData.fromPrimitiveArray(Array(1, 2, 3, 4))) - checkEvaluation(ArrayUnion(a01, a02), Seq(4, 2, 1, 3)) - checkEvaluation(ArrayUnion(a03, a04), Seq(1, 2, null, 4, 5, -5, -3, -1)) - checkEvaluation(ArrayUnion(a04, a03), Seq(-5, 4, -3, 2, -1, 1, null, 5)) - checkEvaluation(ArrayUnion(a03, a05), Seq(1, 2, null, 4, 5)) + checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3, -1)) + checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, -1, 1, null, 5)) + checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5)) checkEvaluation( ArrayUnion(a10, a11), UnsafeArrayData.fromPrimitiveArray(Array(1L, 2L, 3L, 4L))) - checkEvaluation(ArrayUnion(a11, a12), Seq(4L, 2L, 1L, 3L)) - checkEvaluation(ArrayUnion(a13, a14), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) - checkEvaluation(ArrayUnion(a13, a15), Seq(1L, 2L, null, 4L, 5L)) + checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkEvaluation(ArrayUnion(a13, a12), Seq(-5L, 4L, -3L, 2L, -1L, 1L, null, 5L)) + checkEvaluation(ArrayUnion(a12, a14), Seq(1L, 2L, null, 4L, 5L)) checkEvaluation(ArrayUnion(a20, a21), Seq("b", "a", "c", "d", "f")) checkEvaluation(ArrayUnion(a20, a22), Seq("b", "a", "c", null, "g")) - checkEvaluation(ArrayUnion(a23, a24), Seq("b", "a", "c", "d", "f")) checkEvaluation(ArrayUnion(a30, a30), Seq(null)) checkEvaluation(ArrayUnion(a20, a31), null) @@ -1214,7 +1208,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), ArrayType(BinaryType)) val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), - ArrayType(BinaryType, false)) + ArrayType(BinaryType)) val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](4, 3)), ArrayType(BinaryType)) val b3 = Literal.create(Seq[Array[Byte]]( @@ -1228,7 +1222,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](2, 1), Array[Byte](4, 3))) checkEvaluation(ArrayUnion(b0, b2), Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](4, 3))) - checkEvaluation(ArrayUnion(b1, b1), Seq(Array[Byte](2, 1), Array[Byte](4, 3))) checkEvaluation(ArrayUnion(b2, b4), Seq(Array[Byte](1, 2), Array[Byte](4, 3), null)) checkEvaluation(ArrayUnion(b3, b0), Seq(Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](5, 6))) From 90e84b3320e4d9f3a4f4487068c222b10d1d54dc Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 22 Jun 2018 03:19:21 +0100 Subject: [PATCH 28/34] address review comments --- .../catalyst/expressions/CollectionExpressionsSuite.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 04269a6915e87..d8219b2511067 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1173,6 +1173,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType)) val a03 = Literal.create(Seq(-5, 4, -3, 2, -1), ArrayType(IntegerType)) val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) + val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType)) + val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType)) + val a07 = Literal.create(Seq[Short](1, 2, 3), ArrayType(ShortType)) + val a08 = Literal.create(Seq[Short](4, 2), ArrayType(ShortType)) val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType)) val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType)) @@ -1191,6 +1195,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3, -1)) checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, -1, 1, null, 5)) checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5)) + checkEvaluation( + ArrayUnion(a05, a06), UnsafeArrayData.fromPrimitiveArray(Array[Byte](1, 2, 3, 4))) + checkEvaluation( + ArrayUnion(a07, a08), UnsafeArrayData.fromPrimitiveArray(Array[Short](1, 2, 3, 4))) checkEvaluation( ArrayUnion(a10, a11), UnsafeArrayData.fromPrimitiveArray(Array(1L, 2L, 3L, 4L))) From 6f721f0f53361743e86980c8a45807d14a303666 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 22 Jun 2018 07:29:49 +0100 Subject: [PATCH 29/34] address review comments --- .../catalyst/expressions/UnsafeArrayData.java | 12 +- .../expressions/collectionOperations.scala | 154 +++++++++--------- 2 files changed, 82 insertions(+), 84 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 1a69757cb3396..a6f61163ca0aa 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -450,7 +450,7 @@ public double[] toDoubleArray() { return values; } - private static UnsafeArrayData fromPrimitiveArray( + public static UnsafeArrayData fromPrimitiveArray( Object arr, int offset, int length, int elementSize) { final long headerInBytes = calculateHeaderPortionInBytes(length); final long valueRegionInBytes = (long)elementSize * length; @@ -463,14 +463,20 @@ private static UnsafeArrayData fromPrimitiveArray( final long[] data = new long[(int)totalSizeInLongs]; Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length); - Platform.copyMemory(arr, offset, data, - Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); + if (arr != null) { + Platform.copyMemory(arr, offset, data, + Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); + } UnsafeArrayData result = new UnsafeArrayData(); result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8); return result; } + public static UnsafeArrayData 
forPrimitiveArray(int offset, int length, int elementSize) { + return fromPrimitiveArray(null, offset, length, elementSize); + } + public static boolean useGenericArrayData(int elementSize, int length) { final long headerInBytes = calculateHeaderPortionInBytes(length); final long valueRegionInBytes = (long)elementSize * length; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index afe8389440ef3..a94f823cb4aa1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3263,14 +3263,6 @@ case class ArrayDistinct(child: Expression) } object ArraySetLike { - def useGenericArrayData(elementSize: Int, length: Int): Boolean = { - // Use the same calculation in UnsafeArrayData.fromPrimitiveArray() - val headerInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(length) - val valueRegionInBytes = elementSize.toLong * length - val totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8 - totalSizeInLongs > Integer.MAX_VALUE / 8 - } - def throwUnionLengthOverflowException(length: Int): Unit = { throw new RuntimeException(s"Unsuccessful try to union arrays with $length " + s"elements due to exceeding the array size limit " + @@ -3342,7 +3334,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike } } - def evalPrimitiveType( + def evalIntLongPrimitiveType( array1: ArrayData, array2: ArrayData, size: Int, @@ -3386,46 +3378,64 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike // avoid boxing of primitive int array elements // calculate result array size val hsSize = new OpenHashSet[Int] + var nullElementSize = 0 Seq(array1, array2).foreach(array => { var i = 0 while (i < array.numElements()) { - if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + if (hsSize.size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { ArraySetLike.throwUnionLengthOverflowException(hsSize.size) } - hsSize.add(array.getInt(i)) + if (array.isNullAt(i)) { + if (nullElementSize == 0) { + nullElementSize = 1 + } + } else { + hsSize.add(array.getInt(i)) + } i += 1 } }) - if (UnsafeArrayData.useGenericArrayData(IntegerType.defaultSize, hsSize.size)) { - ArrayUnion.evalUnionContainsNull(array1, array2, elementType, - if (elementTypeSupportEquals) null else ordering) + val elements = hsSize.size + nullElementSize + hsInt = new OpenHashSet[Int] + val resultArray = if (UnsafeArrayData.useGenericArrayData( + IntegerType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) } else { - hsInt = new OpenHashSet[Int] - val resultArray = UnsafeArrayData.fromPrimitiveArray(new Array[Int](hsSize.size)) - evalPrimitiveType(array1, array2, hsSize.size, resultArray, false) + UnsafeArrayData.forPrimitiveArray( + Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize); } + evalIntLongPrimitiveType(array1, array2, elements, resultArray, false) case LongType => // avoid boxing of primitive long array elements // calculate result array size val hsSize = new OpenHashSet[Long] + var nullElementSize = 0 Seq(array1, array2).foreach(array => { var i = 0 while (i < array.numElements()) { - if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + if (hsSize.size + nullElementSize> 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { ArraySetLike.throwUnionLengthOverflowException(hsSize.size) } - hsSize.add(array.getLong(i)) + if (array.isNullAt(i)) { + if (nullElementSize == 0) { + nullElementSize = 1 + } + } else { + hsSize.add(array.getLong(i)) + } i += 1 } }) - if (UnsafeArrayData.useGenericArrayData(IntegerType.defaultSize, hsSize.size)) { - ArrayUnion.evalUnionContainsNull(array1, array2, elementType, - if (elementTypeSupportEquals) null else ordering) + val elements = hsSize.size + nullElementSize + hsLong = new OpenHashSet[Long] + val resultArray = if (UnsafeArrayData.useGenericArrayData( + LongType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) } else { - hsLong = new OpenHashSet[Long] - val resultArray = UnsafeArrayData.fromPrimitiveArray(new Array[Long](hsSize.size)) - evalPrimitiveType(array1, array2, hsSize.size, resultArray, true) + UnsafeArrayData.forPrimitiveArray( + Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize); } + evalIntLongPrimitiveType(array1, array2, elements, resultArray, true) case _ => val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] val hs = new OpenHashSet[Any] @@ -3454,8 +3464,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike new GenericArrayData(arrayBuffer) } } else { - ArrayUnion.evalUnionContainsNull(array1, array2, elementType, - if (elementTypeSupportEquals) null else ordering) + ArrayUnion.unionOrdering(array1, array2, elementType, ordering) } } @@ -3464,35 +3473,37 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike val pos = ctx.freshName("pos") val value = ctx.freshName("value") val size = ctx.freshName("size") - val (postFix, classTag, getter, setter, javaTypeName, arrayBuilder) = + val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) = if (elementTypeSupportEquals) { elementType match { case ByteType | ShortType | IntegerType | LongType => val ptName = CodeGenerator.primitiveTypeName(elementType) val unsafeArray = ctx.freshName("unsafeArray") (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", - s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", + if (elementType == LongType) "Long" else "Int", s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + if (elementType == LongType) "(long)" else "(int)", s""" |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} |${ev.value} = $unsafeArray; - """.stripMargin) + """.stripMargin) case _ => val genericArrayData = classOf[GenericArrayData].getName val et = ctx.addReferenceObj("elementType", elementType) - ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", - s"get($i, $et)", s"update($pos, $value)", "Object", + ("", "Object", + s"get($i, $et)", s"update($pos, $value)", "Object", "", s"${ev.value} = new $genericArrayData(new Object[$size]);") } } else { - ("", "", "", "", "", "") + ("", "", "", "", "", "", "") } nullSafeCodeGen(ctx, ev, (array1, array2) => { - if (classTag != "") { - // Here, we ensure elementTypeSupportEquals && !array1.containsNull && !array2.containsNull + if (openHashElementType != "") { + // Here, we ensure elementTypeSupportEquals is true val foundNullElement = ctx.freshName("foundNullElement") val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" val hs = ctx.freshName("hs") val arrayData = classOf[ArrayData].getName val arrays = ctx.freshName("arrays") @@ -3527,7 +3538,7 @@ 
case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike | } | } else { | $javaTypeName $value = $array.$getter; - | if (!$hs.contains($value)) { + | if (!$hs.contains($castOp $value)) { | $hs.add$postFix($value); | ${ev.value}.$setter; | $pos++; @@ -3539,9 +3550,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike } else { val arrayUnion = classOf[ArrayUnion].getName val et = ctx.addReferenceObj("elementTypeUnion", elementType) - val order = if (elementTypeSupportEquals) "null" - else ctx.addReferenceObj("orderingUnion", ordering) - val method = "evalUnionContainsNull" + val order = ctx.addReferenceObj("orderingUnion", ordering) + val method = "unionOrdering" s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);" } }) @@ -3551,57 +3561,39 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike } object ArrayUnion { - def evalUnionContainsNull( + def unionOrdering( array1: ArrayData, array2: ArrayData, elementType: DataType, ordering: Ordering[Any]): ArrayData = { val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - if (ordering == null) { - val hs = new mutable.HashSet[Any] - Seq(array1, array2).foreach(array => { - var i = 0 - while (i < array.numElements()) { - val elem = array.get(i, elementType) - if (hs.add(elem)) { - if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) - } - arrayBuffer += elem - } - i += 1 + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true } - }) - new GenericArrayData(arrayBuffer) - } else { - var alreadyIncludeNull = false - Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { - var found = false - if (elem == null) { - if (alreadyIncludeNull) { + } else { + // check elem is already stored in arrayBuffer or not? + var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { found = true - } else { - alreadyIncludeNull = true - } - } else { - // check elem is already stored in arrayBuffer or not? 
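          // Why the linear scan: for element types where equals() is not
          // value equality (BinaryType in particular: Array[Byte] compares
          // by reference), a hash set would treat equal byte arrays as
          // distinct, so membership is checked against every element kept
          // so far with ordering.equiv, an O(n^2) but correct fallback.
          // In miniature, with java.util.Arrays.equals standing in for the
          // interpreted ordering's equiv:
          //   def seen(buf: Seq[Array[Byte]], e: Array[Byte]): Boolean =
          //     buf.exists(v => java.util.Arrays.equals(v, e))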
- var j = 0 - while (!found && j < arrayBuffer.size) { - val va = arrayBuffer(j) - if (va != null && ordering.equiv(va, elem)) { - found = true - } - j = j + 1 } + j = j + 1 } - if (!found) { - if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) - } - arrayBuffer += elem + } + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) } - })) - new GenericArrayData(arrayBuffer) - } + arrayBuffer += elem + } + })) + new GenericArrayData(arrayBuffer) } } From 0c0d3baebd06403d8e8c339575e7701133ebc324 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 8 Jul 2018 07:05:23 +0100 Subject: [PATCH 30/34] address review comments --- .../expressions/collectionOperations.scala | 25 +++++++---- .../CollectionExpressionsSuite.scala | 43 +++++++++++-------- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a94f823cb4aa1..efbe97b219399 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3272,7 +3272,14 @@ object ArraySetLike { abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { - override def dataType: DataType = left.dataType + override def dataType: DataType = { + val dataTypes = children.map(_.dataType) + dataTypes.headOption.map { + case ArrayType(et, _) => + ArrayType(et, dataTypes.exists(_.asInstanceOf[ArrayType].containsNull)) + case dt => dt + }.getOrElse(StringType) + } override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() @@ -3379,7 +3386,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike // calculate result array size val hsSize = new OpenHashSet[Int] var nullElementSize = 0 - Seq(array1, array2).foreach(array => { + Seq(array1, array2).foreach { array => var i = 0 while (i < array.numElements()) { if (hsSize.size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { @@ -3394,11 +3401,11 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike } i += 1 } - }) + } val elements = hsSize.size + nullElementSize hsInt = new OpenHashSet[Int] val resultArray = if (UnsafeArrayData.useGenericArrayData( - IntegerType.defaultSize, elements)) { + IntegerType.defaultSize, elements)) { new GenericArrayData(new Array[Any](elements)) } else { UnsafeArrayData.forPrimitiveArray( @@ -3410,10 +3417,10 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike // calculate result array size val hsSize = new OpenHashSet[Long] var nullElementSize = 0 - Seq(array1, array2).foreach(array => { + Seq(array1, array2).foreach { array => var i = 0 while (i < array.numElements()) { - if (hsSize.size + nullElementSize> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + if (hsSize.size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { ArraySetLike.throwUnionLengthOverflowException(hsSize.size) } if (array.isNullAt(i)) { @@ -3425,7 +3432,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike } i += 1 } - }) + } val elements = hsSize.size + nullElementSize 
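          // The sizing pass just above counts distinct non-null values plus
          // at most one slot for null before any allocation happens. In
          // miniature, with Option standing in for a SQL null element:
          //   val xs = Seq[Option[Long]](Some(1L), Some(2L), None)
          //   val ys = Seq[Option[Long]](Some(2L), None)
          //   val distinctVals = (xs ++ ys).flatten.distinct.size      // 2
          //   val nullSlot = if ((xs ++ ys).contains(None)) 1 else 0   // 1
          //   val elements = distinctVals + nullSlot                   // 3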
hsLong = new OpenHashSet[Long] val resultArray = if (UnsafeArrayData.useGenericArrayData( @@ -3440,7 +3447,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] val hs = new OpenHashSet[Any] var foundNullElement = false - Seq(array1, array2).foreach(array => { + Seq(array1, array2).foreach { array => var i = 0 while (i < array.numElements()) { if (array.isNullAt(i)) { @@ -3460,7 +3467,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike } i += 1 } - }) + } new GenericArrayData(arrayBuffer) } } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d8219b2511067..42ce7dc6fa0f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1168,25 +1168,25 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("Array Union") { - val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) - val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType)) - val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType)) - val a03 = Literal.create(Seq(-5, 4, -3, 2, -1), ArrayType(IntegerType)) - val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) - val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType)) - val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType)) - val a07 = Literal.create(Seq[Short](1, 2, 3), ArrayType(ShortType)) - val a08 = Literal.create(Seq[Short](4, 2), ArrayType(ShortType)) - - val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType)) - val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType)) - val a12 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType)) - val a13 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), ArrayType(LongType)) - val a14 = Literal.create(Seq.empty[Long], ArrayType(LongType)) - - val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType)) - val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType)) - val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType)) + val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, containsNull = false)) + val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true)) + val a03 = Literal.create(Seq(-5, 4, -3, 2, -1), ArrayType(IntegerType, containsNull = false)) + val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType, containsNull = false)) + val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType, containsNull = false)) + val a07 = Literal.create(Seq[Short](1, 2, 3), ArrayType(ShortType, containsNull = false)) + val a08 = Literal.create(Seq[Short](4, 2), ArrayType(ShortType, containsNull = false)) + + val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, containsNull = false)) + val a12 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType, containsNull = true)) + val a13 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), 
+      ArrayType(LongType, containsNull = false))
+    val a14 = Literal.create(Seq.empty[Long], ArrayType(LongType, containsNull = false))
+
+    val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, containsNull = false))
+    val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, containsNull = false))
+    val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, containsNull = true))
 
     val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType))
     val a31 = Literal.create(null, ArrayType(StringType))
@@ -1244,5 +1244,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       ArrayType(ArrayType(IntegerType)))
     checkEvaluation(ArrayUnion(aa0, aa1),
       Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](5, 6), Seq[Int](2, 1)))
+
+    assert(ArrayUnion(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false)
+    assert(ArrayUnion(a00, a02).dataType.asInstanceOf[ArrayType].containsNull === true)
+    assert(ArrayUnion(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false)
+    assert(ArrayUnion(a20, a22).dataType.asInstanceOf[ArrayType].containsNull === true)
   }
 }

From 4a217bc7e5fc534f460c8587f3d2509c04d71282 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sun, 8 Jul 2018 16:41:45 +0100
Subject: [PATCH 31/34] cleanup

---
 .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index efbe97b219399..8684c9327799f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3409,7 +3409,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
           new GenericArrayData(new Array[Any](elements))
         } else {
           UnsafeArrayData.forPrimitiveArray(
-            Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize);
+            Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
         }
         evalIntLongPrimitiveType(array1, array2, elements, resultArray, false)
       case LongType =>
@@ -3440,7 +3440,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
           new GenericArrayData(new Array[Any](elements))
         } else {
           UnsafeArrayData.forPrimitiveArray(
-            Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize);
+            Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
         }
         evalIntLongPrimitiveType(array1, array2, elements, resultArray, true)
       case _ =>
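The assertions added to the suite above pin down the nullability contract: ArrayUnion's result element type is nullable exactly when either input's element type is. A minimal standalone sketch of that rule in Scala (a hypothetical helper for illustration, not part of the patch):

    import org.apache.spark.sql.types.ArrayType

    // Result type of an array set operation: same element type, with
    // containsNull true as soon as either side may hold null elements.
    def unionArrayType(left: ArrayType, right: ArrayType): ArrayType = {
      require(left.elementType == right.elementType,
        "Element type in both arrays must be the same")
      ArrayType(left.elementType, left.containsNull || right.containsNull)
    }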
From f5ebbe80e9461f3b2fb2b82cf9446c74de4e2e77 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sun, 8 Jul 2018 18:04:04 +0100
Subject: [PATCH 32/34] eliminate duplicated code

---
 .../expressions/collectionOperations.scala    | 75 +++++++------------
 1 file changed, 26 insertions(+), 49 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 8684c9327799f..9f14f96f432d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3322,7 +3322,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
   def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
     val elem = array.getInt(idx)
     if (!hsInt.contains(elem)) {
-      resultArray.setInt(pos, elem)
+      if (resultArray != null) {
+        resultArray.setInt(pos, elem)
+      }
       hsInt.add(elem)
       true
     } else {
@@ -3333,7 +3335,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
   def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
     val elem = array.getLong(idx)
     if (!hsLong.contains(elem)) {
-      resultArray.setLong(pos, elem)
+      if (resultArray != null) {
+        resultArray.setLong(pos, elem)
+      }
       hsLong.add(elem)
       true
     } else {
@@ -3344,20 +3348,25 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
   def evalIntLongPrimitiveType(
       array1: ArrayData,
       array2: ArrayData,
-      size: Int,
       resultArray: ArrayData,
-      isLongType: Boolean): ArrayData = {
+      isLongType: Boolean): Int = {
     // store elements into resultArray
-    var foundNullElement = false
+    var nullElementSize = 0
     var pos = 0
     Seq(array1, array2).foreach(array => {
      var i = 0
       while (i < array.numElements()) {
+        val size = if (!isLongType) hsInt.size else hsLong.size
+        if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+          ArraySetLike.throwUnionLengthOverflowException(size)
+        }
         if (array.isNullAt(i)) {
-          if (!foundNullElement) {
-            resultArray.setNullAt(pos)
+          if (nullElementSize == 0) {
+            if (resultArray != null) {
+              resultArray.setNullAt(pos)
+            }
             pos += 1
-            foundNullElement = true
+            nullElementSize = 1
         } else {
           val assigned = if (!isLongType) {
@@ -3372,7 +3381,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
         i += 1
       }
     })
-    resultArray
+    pos
   }
 
   override def nullSafeEval(input1: Any, input2: Any): Any = {
@@ -3384,25 +3393,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
       case IntegerType =>
        // avoid boxing of primitive int array elements
         // calculate result array size
-        val hsSize = new OpenHashSet[Int]
-        var nullElementSize = 0
-        Seq(array1, array2).foreach { array =>
-          var i = 0
-          while (i < array.numElements()) {
-            if (hsSize.size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-              ArraySetLike.throwUnionLengthOverflowException(hsSize.size)
-            }
-            if (array.isNullAt(i)) {
-              if (nullElementSize == 0) {
-                nullElementSize = 1
-              }
-            } else {
-              hsSize.add(array.getInt(i))
-            }
-            i += 1
-          }
-        }
-        val elements = hsSize.size + nullElementSize
+        hsInt = new OpenHashSet[Int]
+        val elements = evalIntLongPrimitiveType(array1, array2, null, false)
         hsInt = new OpenHashSet[Int]
         val resultArray = if (UnsafeArrayData.useGenericArrayData(
           IntegerType.defaultSize, elements)) {
@@ -3411,29 +3403,13 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
           UnsafeArrayData.forPrimitiveArray(
             Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
         }
-        evalIntLongPrimitiveType(array1, array2, elements, resultArray, false)
+        evalIntLongPrimitiveType(array1, array2, resultArray, false)
+        resultArray
       case LongType =>
         // avoid boxing of primitive long array elements
         // calculate result array size
-        val hsSize = new OpenHashSet[Long]
-        var nullElementSize = 0
-        Seq(array1, array2).foreach { array =>
-          var i = 0
-          while (i < array.numElements()) {
-            if (hsSize.size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-              ArraySetLike.throwUnionLengthOverflowException(hsSize.size)
-            }
-            if (array.isNullAt(i)) {
-              if (nullElementSize == 0) {
-                nullElementSize = 1
-              }
-            } else {
-              hsSize.add(array.getLong(i))
-            }
-            i += 1
-          }
-        }
-        val elements = hsSize.size + nullElementSize
+        hsLong = new OpenHashSet[Long]
+        val elements = evalIntLongPrimitiveType(array1, array2, null, true)
         hsLong = new OpenHashSet[Long]
         val resultArray = if (UnsafeArrayData.useGenericArrayData(
           LongType.defaultSize, elements)) {
@@ -3442,7 +3418,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
           UnsafeArrayData.forPrimitiveArray(
             Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
         }
-        evalIntLongPrimitiveType(array1, array2, elements, resultArray, true)
+        evalIntLongPrimitiveType(array1, array2, resultArray, true)
+        resultArray
       case _ =>
         val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
         val hs = new OpenHashSet[Any]
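Patch 32 above removes the duplicated sizing loops by calling evalIntLongPrimitiveType twice: once with a null result array purely to count distinct elements (plus at most one null slot), then again over a freshly allocated array to fill it. A minimal standalone sketch of that two-pass shape, using plain Scala collections instead of the Spark internals and ignoring SQL NULL handling:

    import scala.collection.mutable

    // One pass over both inputs; when out == null we only count distinct
    // elements, when out != null we also write each one at its final position.
    def unionPass(a: Array[Int], b: Array[Int], out: Array[Int]): Int = {
      val seen = mutable.HashSet.empty[Int]
      var pos = 0
      for (elem <- a.iterator ++ b.iterator) {
        if (seen.add(elem)) {
          if (out != null) out(pos) = elem
          pos += 1
        }
      }
      pos
    }

    def union(a: Array[Int], b: Array[Int]): Array[Int] = {
      val n = unionPass(a, b, null)   // sizing pass
      val result = new Array[Int](n)  // exact-size primitive allocation
      unionPass(a, b, result)         // filling pass
      result
    }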
From 763a1f87b37d2569c9dc4f1a626c3df0438fc42f Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Mon, 9 Jul 2018 20:21:02 +0100
Subject: [PATCH 33/34] address review comments

---
 .../catalyst/expressions/UnsafeArrayData.java |  2 +-
 .../expressions/collectionOperations.scala    | 37 +++++++++----------
 .../CollectionExpressionsSuite.scala          | 17 ++++-----
 3 files changed, 26 insertions(+), 30 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index a6f61163ca0aa..d7ac6b1730b47 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -477,7 +477,7 @@ public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elem
     return fromPrimitiveArray(null, offset, length, elementSize);
   }
 
-  public static boolean useGenericArrayData(int elementSize, int length) {
+  public static boolean canUseGenericArrayData(int elementSize, int length) {
     final long headerInBytes = calculateHeaderPortionInBytes(length);
     final long valueRegionInBytes = (long)elementSize * length;
     final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 9f14f96f432d7..9327491bb7d59 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3262,23 +3262,13 @@ case class ArrayDistinct(child: Expression)
   override def prettyName: String = "array_distinct"
 }
 
-object ArraySetLike {
-  def throwUnionLengthOverflowException(length: Int): Unit = {
-    throw new RuntimeException(s"Unsuccessful try to union arrays with $length " +
-      s"elements due to exceeding the array size limit " +
-      s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
-  }
-}
-
-
+/**
+ * Will become a common base class for [[ArrayUnion]], ArrayIntersect, and ArrayExcept.
+ */
 abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
   override def dataType: DataType = {
-    val dataTypes = children.map(_.dataType)
-    dataTypes.headOption.map {
-      case ArrayType(et, _) =>
-        ArrayType(et, dataTypes.exists(_.asInstanceOf[ArrayType].containsNull))
-      case dt => dt
-    }.getOrElse(StringType)
+    val dataTypes = children.map(_.dataType.asInstanceOf[ArrayType])
+    ArrayType(elementType, dataTypes.exists(_.containsNull))
   }
 
   override def checkInputDataTypes(): TypeCheckResult = {
@@ -3301,6 +3291,15 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
     }
   }
 
+object ArraySetLike {
+  def throwUnionLengthOverflowException(length: Int): Unit = {
+    throw new RuntimeException(s"Unsuccessful try to union arrays with $length " +
+      s"elements due to exceeding the array size limit " +
+      s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
+  }
+}
+
+
 /**
  * Returns an array of the elements in the union of x and y, without duplicates
  */
@@ -3353,7 +3352,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
     // store elements into resultArray
     var nullElementSize = 0
     var pos = 0
-    Seq(array1, array2).foreach(array => {
+    Seq(array1, array2).foreach { array =>
       var i = 0
       while (i < array.numElements()) {
         val size = if (!isLongType) hsInt.size else hsLong.size
@@ -3380,7 +3379,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
         }
         i += 1
       }
-    })
+    }
     pos
   }
 
@@ -3396,7 +3395,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
         hsInt = new OpenHashSet[Int]
         val elements = evalIntLongPrimitiveType(array1, array2, null, false)
         hsInt = new OpenHashSet[Int]
-        val resultArray = if (UnsafeArrayData.useGenericArrayData(
+        val resultArray = if (UnsafeArrayData.canUseGenericArrayData(
           IntegerType.defaultSize, elements)) {
           new GenericArrayData(new Array[Any](elements))
         } else {
@@ -3411,7 +3410,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
         hsLong = new OpenHashSet[Long]
         val elements = evalIntLongPrimitiveType(array1, array2, null, true)
         hsLong = new OpenHashSet[Long]
-        val resultArray = if (UnsafeArrayData.useGenericArrayData(
+        val resultArray = if (UnsafeArrayData.canUseGenericArrayData(
           LongType.defaultSize, elements)) {
           new GenericArrayData(new Array[Any](elements))
         } else {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 42ce7dc6fa0f5..bd0035e66ca8e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1171,7 +1171,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
     val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, containsNull = false))
     val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true))
-    val a03 = Literal.create(Seq(-5, 4, -3, 2, -1), ArrayType(IntegerType, containsNull = false))
+    val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false))
     val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false))
     val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType, containsNull = false))
     val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType, containsNull = false))
@@ -1191,17 +1191,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType))
     val a31 = Literal.create(null, ArrayType(StringType))
 
-    checkEvaluation(ArrayUnion(a00, a01), UnsafeArrayData.fromPrimitiveArray(Array(1, 2, 3, 4)))
-    checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3, -1))
-    checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, -1, 1, null, 5))
+    checkEvaluation(ArrayUnion(a00, a01), Seq(1, 2, 3, 4))
+    checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3))
+    checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, 1, null, 5))
     checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5))
-    checkEvaluation(
-      ArrayUnion(a05, a06), UnsafeArrayData.fromPrimitiveArray(Array[Byte](1, 2, 3, 4)))
-    checkEvaluation(
-      ArrayUnion(a07, a08), UnsafeArrayData.fromPrimitiveArray(Array[Short](1, 2, 3, 4)))
+    checkEvaluation(ArrayUnion(a05, a06), Seq[Byte](1, 2, 3, 4))
+    checkEvaluation(ArrayUnion(a07, a08), Seq[Short](1, 2, 3, 4))
 
-    checkEvaluation(
-      ArrayUnion(a10, a11), UnsafeArrayData.fromPrimitiveArray(Array(1L, 2L, 3L, 4L)))
+    checkEvaluation(ArrayUnion(a10, a11), Seq(1L, 2L, 3L, 4L))
     checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L))
     checkEvaluation(ArrayUnion(a13, a12), Seq(-5L, 4L, -3L, 2L, -1L, 1L, null, 5L))
     checkEvaluation(ArrayUnion(a12, a14), Seq(1L, 2L, null, 4L, 5L))
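Patch 33 above renames UnsafeArrayData.useGenericArrayData to canUseGenericArrayData (and the final patch below settles on shouldUseGenericArrayData); the predicate decides when a primitive result is too large for a single UnsafeArrayData and must fall back to GenericArrayData. A rough standalone restatement of the check, in Scala; the header formula and final threshold are assumptions for illustration, not copied from UnsafeArrayData.java:

    // Assumed header: 8 bytes for numElements plus one null bit per element,
    // rounded to 8-byte words. Fall back to the generic representation once
    // the word-aligned total exceeds the assumed addressable limit.
    def shouldUseGenericArrayData(elementSize: Int, length: Int): Boolean = {
      val headerInBytes = 8L + ((length + 63) / 64) * 8L
      val valueRegionInBytes = elementSize.toLong * length
      val totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8
      totalSizeInLongs > Int.MaxValue / 8
    }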
From 7b515649c75dc68d5d74ebca15628b6377a05344 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Wed, 11 Jul 2018 14:49:28 +0100
Subject: [PATCH 34/34] address review comment

---
 .../spark/sql/catalyst/expressions/UnsafeArrayData.java       | 2 +-
 .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index d7ac6b1730b47..cf2a5ed2e27f9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -477,7 +477,7 @@ public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elem
     return fromPrimitiveArray(null, offset, length, elementSize);
   }
 
-  public static boolean canUseGenericArrayData(int elementSize, int length) {
+  public static boolean shouldUseGenericArrayData(int elementSize, int length) {
     final long headerInBytes = calculateHeaderPortionInBytes(length);
     final long valueRegionInBytes = (long)elementSize * length;
     final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 9327491bb7d59..3818b7072b353 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3395,7 +3395,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
         hsInt = new OpenHashSet[Int]
         val elements = evalIntLongPrimitiveType(array1, array2, null, false)
         hsInt = new OpenHashSet[Int]
-        val resultArray = if (UnsafeArrayData.canUseGenericArrayData(
+        val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
           IntegerType.defaultSize, elements)) {
           new GenericArrayData(new Array[Any](elements))
         } else {
@@ -3410,7 +3410,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
         hsLong = new OpenHashSet[Long]
         val elements = evalIntLongPrimitiveType(array1, array2, null, true)
         hsLong = new OpenHashSet[Long]
-        val resultArray = if (UnsafeArrayData.canUseGenericArrayData(
+        val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
           LongType.defaultSize, elements)) {
           new GenericArrayData(new Array[Any](elements))
         } else {
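With the full series applied, the new function can be smoke-tested end to end from the DataFrame API; a quick sketch against a local SparkSession (session setup assumed, output abbreviated):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{array_union, col}

    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("array_union-demo")
      .getOrCreate()
    import spark.implicits._

    // Duplicates across (and within) the two arrays collapse to one occurrence.
    val df = Seq((Seq(1, 2, 3), Seq(1, 3, 5))).toDF("a", "b")
    df.select(array_union(col("a"), col("b"))).show()
    // expected single row: [1, 2, 3, 5]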