From bc5a02dc419930f3790895614639868d7e5c4116 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 18 Apr 2018 17:01:50 +0100 Subject: [PATCH 01/21] initial commit --- python/pyspark/sql/functions.py | 19 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 229 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 41 ++++ .../org/apache/spark/sql/functions.scala | 11 + .../spark/sql/DataFrameFunctionsSuite.scala | 44 +++- 6 files changed, 343 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ec014a5b39c31..560c3064d620b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2033,6 +2033,25 @@ def array_distinct(col): return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) +@ignore_unicode_prefix +@since(2.4) +def array_intersect(col1, col2): + """ + Collection function: returns an array of the elements in the intersection of col1 and col2, + without duplicates. The order of elements in the result is not determined. + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_intersect(df.c1, df.c2)).collect() + [Row(array_intersect(c1, c2)=[u'c', u'a']))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_intersect(_to_java_column(col1), _to_java_column(col2))) + + @ignore_unicode_prefix @since(2.4) def array_union(col1, col2): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d0efe975f81ce..10f89c9ee02db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -411,6 +411,7 @@ object FunctionRegistry { expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), expression[ArraysOverlap]("arrays_overlap"), + expression[ArrayIntersect]("array_intersect"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3f94f25796634..2d754e0fa2362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3101,6 +3101,214 @@ object Sequence { } } +abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { + val kindUnion = 1 + val kindIntersect = 2 + def typeId: Int + + def array1: Expression + def array2: Expression + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { + val r = super.checkInputDataTypes() + if ((r == TypeCheckResult.TypeCheckSuccess) && + (array1.dataType.asInstanceOf[ArrayType].elementType != + array2.dataType.asInstanceOf[ArrayType].elementType)) { + TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") + } else { + r + } + } + + override def dataType: DataType = array1.dataType + + private def elementType = dataType.asInstanceOf[ArrayType].elementType + private def cn1 = array1.dataType.asInstanceOf[ArrayType].containsNull + private def cn2 = array2.dataType.asInstanceOf[ArrayType].containsNull + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val ary1 = input1.asInstanceOf[ArrayData] + val ary2 = input2.asInstanceOf[ArrayData] + + if (!cn1 && !cn2) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + var hs: OpenHashSet[Int] = null + val hs1 = new OpenHashSet[Int] + var i = 0 + while (i < ary1.numElements()) { + hs1.add(ary1.getInt(i)) + i += 1 + } + if (typeId == kindUnion) { + i = 0 + while (i < ary2.numElements()) { + hs1.add(ary2.getInt(i)) + i += 1 + } + hs = hs1 + } else if (typeId == kindIntersect) { + hs = new OpenHashSet[Int] + i = 0 + while (i < ary2.numElements()) { + val k = ary2.getInt(i) + if (hs1.contains(k)) { + hs.add(k) + } + i += 1 + } + } + UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) + case LongType => + // avoid boxing of primitive long array elements + var hs: OpenHashSet[Long] = null + val hs1 = new OpenHashSet[Long] + var i = 0 + while (i < ary1.numElements()) { + hs1.add(ary1.getLong(i)) + i += 1 + } + if (typeId == kindUnion) { + i = 0 + while (i < ary2.numElements()) { + hs1.add(ary2.getLong(i)) + i += 1 + } + hs = hs1 + } else if (typeId == kindIntersect) { + hs = new OpenHashSet[Long] + i = 0 + while (i < ary2.numElements()) { + val k = ary2.getLong(i) + if (hs1.contains(k)) { + hs.add(k) + } + i += 1 + } + } + UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) + case _ => + var hs: OpenHashSet[Any] = null + val hs1 = new OpenHashSet[Any] + var i = 0 + while (i < ary1.numElements()) { + hs1.add(ary1.get(i, elementType)) + i += 1 + } + if (typeId == kindUnion) { + i = 0 + while (i < ary2.numElements()) { + hs1.add(ary2.get(i, elementType)) + i += 1 + } + hs = hs1 + } else if (typeId == kindIntersect) { + hs = new OpenHashSet[Any] + i = 0 + while (i < ary2.numElements()) { + val k = ary2.get(i, elementType) + if (hs1.contains(k)) { + hs.add(k) + } + i += 1 + } + } + new GenericArrayData(hs.iterator.toArray) + } + } else { + if (typeId == kindUnion) { + ArraySetUtils.arrayUnion(ary1, ary2, elementType) + } else if (typeId == kindIntersect) { + ArraySetUtils.arrayIntersect(ary1, ary2, elementType) + } + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val hs = ctx.freshName("hs") + val hs1 = ctx.freshName("hs1") + val i = ctx.freshName("i") + val ArraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetUtils" + val genericArrayData = classOf[GenericArrayData].getName + val unsafeArrayData = classOf[UnsafeArrayData].getName + val openHashSet = classOf[OpenHashSet[_]].getName + val et = s"org.apache.spark.sql.types.DataTypes.$elementType" + val (postFix, classTag, getter, arrayBuilder, castType) = if (!cn1 && !cn2) { + val ptName = CodeGenerator.primitiveTypeName(elementType) + elementType match { + case ByteType | ShortType | IntegerType => + (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", + s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType)) + case LongType => + (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", + s"$unsafeArrayData.fromPrimitiveArray", "long") + case _ => + ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $et)", + s"new $genericArrayData", "Object") + } + } else { + ("", "", "", "", "") + } + + nullSafeCodeGen(ctx, ev, (ary1, ary2) => { + if (classTag != "") { + if (typeId == kindUnion) { + s""" + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |for (int $i = 0; $i < $ary1.numElements(); $i++) { + | $hs.add$postFix($ary1.$getter); + |} + |for (int $i = 0; $i < $ary2.numElements(); $i++) { + | $hs.add$postFix($ary2.$getter); + |} + |${ev.value} = $arrayBuilder(($castType[]) $hs.iterator().toArray($classTag)); + """.stripMargin + } else if (typeId == kindIntersect) { + s""" + |$openHashSet $hs1 = new $openHashSet$postFix($classTag); + |for (int $i = 0; $i < $ary1.numElements(); $i++) { + | $hs1.add$postFix($ary1.$getter); + |} + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |for (int $i = 0; $i < $ary2.numElements(); $i++) { + | if ($hs1.contains$postFix($ary2.$getter)) { + | $hs.add$postFix($ary2.$getter); + | } + |} + |${ev.value} = $arrayBuilder(($castType[]) $hs.iterator().toArray($classTag)); + """.stripMargin + } else { + throw new UnsupportedOperationException("typeId=" + typeId + "is not supported.") + } + } else { + val setOp = if (typeId == kindUnion) { + "Union" + } else if (typeId == kindIntersect) { + "Intersect" + } else { + throw new UnsupportedOperationException("typeId=" + typeId + "is not supported.") + } + s"${ev.value} = $ArraySetUtils$$.MODULE$$.array$setOp($ary1, $ary2, $et);" + } + }) + } +} + +object ArraySetUtils { + def arrayUnion(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { + new GenericArrayData(array1.toArray[AnyRef](et).union(array2.toArray[AnyRef](et)) + .distinct.asInstanceOf[Array[Any]]) + } + + def arrayIntersect(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { + new GenericArrayData(array1.toArray[AnyRef](et).intersect(array2.toArray[AnyRef](et)) + .distinct.asInstanceOf[Array[Any]]) + } +} + /** * Returns the array containing the given input value (left) count (right) times. */ @@ -3963,6 +4171,27 @@ object ArrayUnion { })) new GenericArrayData(arrayBuffer) } + + /** + * Returns an array of the elements in the intersect of x and y, without duplicates + */ + @ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the intersection of array1 and + array2, without duplicates. The order of elements in the result is not determined. + """, + examples = """ + Examples:Fun + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 3) + """, + since = "2.4.0") +case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetUtils { + override def typeId: Int = kindIntersect + override def array1: Expression = left + override def array2: Expression = right + + override def prettyName: String = "array_intersect" } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 2f6f9064f9e62..d7e81b67138fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1618,4 +1618,45 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(ArrayExcept(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true) } + + test("Array Intersect") { + val a00 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType, false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) + val a02 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType)) + val a03 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType)) + val a04 = Literal.create(Seq(-5, 4, null, 2, -1), ArrayType(IntegerType)) + val a05 = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) + + val a10 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType, false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) + val a12 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType)) + val a13 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType)) + val a14 = Literal.create(Seq(-5L, 4L, null, 2L, -1L), ArrayType(LongType)) + val a15 = Literal.create(Seq.empty[Long], ArrayType(LongType)) + + val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType)) + val a21 = Literal.create(Seq("c", null, "a", "f"), ArrayType(StringType)) + val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType)) + val a23 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false)) + val a24 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, false)) + + val a30 = Literal.create(Seq(null, null), ArrayType(NullType)) + + checkEvaluation(ArrayIntersect(a00, a01), UnsafeArrayData.fromPrimitiveArray(Array(4, 2))) + checkEvaluation(ArrayIntersect(a01, a02), Seq(4, 2)) + checkEvaluation(ArrayIntersect(a03, a04), Seq(2, null, 4)) + checkEvaluation(ArrayIntersect(a03, a05), Seq.empty) + + checkEvaluation( + ArrayIntersect(a10, a11), UnsafeArrayData.fromPrimitiveArray(Array(4L, 2L))) + checkEvaluation(ArrayIntersect(a11, a12), Seq(4L, 2L)) + checkEvaluation(ArrayIntersect(a13, a14), Seq(2L, null, 4L)) + checkEvaluation(ArrayIntersect(a13, a15), Seq.empty) + + checkEvaluation(ArrayIntersect(a20, a21), Seq("a", "c")) + checkEvaluation(ArrayIntersect(a21, a22), Seq(null, "a")) + checkEvaluation(ArrayIntersect(a23, a24), Seq("c", "a")) + + checkEvaluation(ArrayIntersect(a30, a30), Seq(null)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cc739b85f555c..310e428b69819 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3233,6 +3233,17 @@ object functions { */ def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + /** + * Returns an array of the elements in the intersection of the given two arrays, + * without duplicates. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_intersect(col1: Column, col2: Column): Column = withExpr { + ArrayIntersect(col1.expr, col2.expr) + } + /** * Returns an array of the elements in the union of the given two arrays, without duplicates. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 1d5707a2c7047..eb08635fc7689 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1173,8 +1173,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("concat function - arrays") { - val nseqi : Seq[Int] = null - val nseqs : Seq[String] = null + val nseqi: Seq[Int] = null + val nseqs: Seq[String] = null val df = Seq( (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) @@ -1647,6 +1647,46 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(result10.first.schema(0).dataType === expectedType10) } + test("array_intersect functions") { + val df1 = Seq((Array(1, 2, 4), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(4, 2)) + checkAnswer(df1.select(array_intersect($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_intersect(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array[Integer](-5, 4, null, 2, -1))) + .toDF("a", "b") + val ans2 = Row(Seq(2, null, 4)) + checkAnswer(df2.select(array_intersect($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_intersect(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 4L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(4L, 2L)) + checkAnswer(df3.select(array_intersect($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_intersect(a, b)"), ans3) + + val ans4 = Row(Seq(2L, null, 4L)) + checkAnswer(df4.select(array_intersect($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_intersect(a, b)"), ans4) + + val df5 = Seq((Array("c", null, "a", "f"), Array("b", null, "a", "g"))).toDF("a", "b") + val ans5 = Row(Seq(null, "a")) + checkAnswer(df5.select(array_intersect($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_intersect(a, b)"), ans5) + + val df6 = Seq((null, null)).toDF("a", "b") + val ans6 = Row(null) + checkAnswer(df6.select(array_intersect($"a", $"b")), ans6) + checkAnswer(df6.selectExpr("array_intersect(a, b)"), ans6) + + val df0 = Seq((Array(1), Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df0.select(array_intersect($"a", $"b")) + } + intercept[AnalysisException] { + df0.selectExpr("array_intersect(a, b)") + } + } + test("transform function - array for primitive type not containing null") { val df = Seq( Seq(1, 9, 8, 7), From c7f230deaa10783ac79d6c6c79f99de434d43814 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 18 Apr 2018 19:22:48 +0100 Subject: [PATCH 02/21] simplification --- .../expressions/collectionOperations.scala | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2d754e0fa2362..49975d63f28fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3106,27 +3106,24 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { val kindIntersect = 2 def typeId: Int - def array1: Expression - def array2: Expression - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) override def checkInputDataTypes(): TypeCheckResult = { val r = super.checkInputDataTypes() if ((r == TypeCheckResult.TypeCheckSuccess) && - (array1.dataType.asInstanceOf[ArrayType].elementType != - array2.dataType.asInstanceOf[ArrayType].elementType)) { + (left.dataType.asInstanceOf[ArrayType].elementType != + right.dataType.asInstanceOf[ArrayType].elementType)) { TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") } else { r } } - override def dataType: DataType = array1.dataType + override def dataType: DataType = left.dataType private def elementType = dataType.asInstanceOf[ArrayType].elementType - private def cn1 = array1.dataType.asInstanceOf[ArrayType].containsNull - private def cn2 = array2.dataType.asInstanceOf[ArrayType].containsNull + private def cn1 = left.dataType.asInstanceOf[ArrayType].containsNull + private def cn2 = right.dataType.asInstanceOf[ArrayType].containsNull override def nullSafeEval(input1: Any, input2: Any): Any = { val ary1 = input1.asInstanceOf[ArrayData] @@ -4188,8 +4185,6 @@ object ArrayUnion { since = "2.4.0") case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetUtils { override def typeId: Int = kindIntersect - override def array1: Expression = left - override def array2: Expression = right override def prettyName: String = "array_intersect" } From e96ab2df0b22826cc9ab10ae32e7bd66be9f1786 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 03:05:34 +0100 Subject: [PATCH 03/21] fix pyspark test failure --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 560c3064d620b..42fca146f3ba9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2046,7 +2046,7 @@ def array_intersect(col1, col2): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_intersect(df.c1, df.c2)).collect() - [Row(array_intersect(c1, c2)=[u'c', u'a']))] + [Row(array_intersect(c1, c2)=[u'a', u'c'])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_intersect(_to_java_column(col1), _to_java_column(col2))) From 49bc3192d09539f288134f09c3b2d18b5d14059b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 10 Jul 2018 12:17:45 +0100 Subject: [PATCH 04/21] update --- .../spark/util/collection/OpenHashSet.scala | 11 + .../util/collection/OpenHashSetSuite.scala | 63 +++ .../expressions/collectionOperations.scala | 525 +++++++++++------- .../CollectionExpressionsSuite.scala | 115 +++- .../spark/sql/DataFrameFunctionsSuite.scala | 30 +- 5 files changed, 499 insertions(+), 245 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 8883e17bf3164..739c862ed60e4 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -118,6 +118,17 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( rehashIfNeeded(k, grow, move) } + /** + * Remove an element from the set. If an element does not exists in the set, nothing is done. + */ + def remove(k: T): Unit = { + val pos = hashcode(hasher.hash(k)) & _mask + if (_bitset.get(pos)) { + _bitset.unset(pos) + _size -= 1 + } + } + def union(other: OpenHashSet[T]): OpenHashSet[T] = { val iterator = other.iterator while (iterator.hasNext) { diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index b887f937a9da9..6e9fb9f6f71f4 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -73,6 +73,27 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.contains(50)) assert(set.contains(999)) assert(!set.contains(10000)) + + set.remove(999) + assert(set.size === 2) + assert(set.contains(10)) + assert(set.contains(50)) + assert(!set.contains(999)) + assert(!set.contains(10000)) + + set.add(999) + assert(set.size === 3) + assert(set.contains(10)) + assert(set.contains(50)) + assert(set.contains(999)) + assert(!set.contains(10000)) + + set.remove(10000) + assert(set.size === 3) + assert(set.contains(10)) + assert(set.contains(50)) + assert(set.contains(999)) + assert(!set.contains(10000)) } test("primitive long") { @@ -110,6 +131,27 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.contains(50L)) assert(set.contains(999L)) assert(!set.contains(10000L)) + + set.remove(999L) + assert(set.size === 2) + assert(set.contains(10L)) + assert(set.contains(50L)) + assert(!set.contains(999L)) + assert(!set.contains(10000L)) + + set.add(999L) + assert(set.size === 3) + assert(set.contains(10L)) + assert(set.contains(50L)) + assert(set.contains(999L)) + assert(!set.contains(10000L)) + + set.remove(10000L) + assert(set.size === 3) + assert(set.contains(10L)) + assert(set.contains(50L)) + assert(!set.contains(999L)) + assert(!set.contains(10000L)) } test("primitive float") { @@ -221,6 +263,27 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.contains(50.toString)) assert(set.contains(999.toString)) assert(!set.contains(10000.toString)) + + set.remove(999.toString) + assert(set.size === 2) + assert(set.contains(10.toString)) + assert(set.contains(50.toString)) + assert(!set.contains(999.toString)) + assert(!set.contains(10000.toString)) + + set.add(999.toString) + assert(set.size === 3) + assert(set.contains(10.toString)) + assert(set.contains(50.toString)) + assert(set.contains(999.toString)) + assert(!set.contains(10000.toString)) + + set.remove(10000.toString) + assert(set.size === 3) + assert(set.contains(10.toString)) + assert(set.contains(50.toString)) + assert(set.contains(999.toString)) + assert(!set.contains(10000.toString)) } test("non-primitive set growth") { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 49975d63f28fb..171804e4e8a4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3101,211 +3101,6 @@ object Sequence { } } -abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { - val kindUnion = 1 - val kindIntersect = 2 - def typeId: Int - - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) - - override def checkInputDataTypes(): TypeCheckResult = { - val r = super.checkInputDataTypes() - if ((r == TypeCheckResult.TypeCheckSuccess) && - (left.dataType.asInstanceOf[ArrayType].elementType != - right.dataType.asInstanceOf[ArrayType].elementType)) { - TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") - } else { - r - } - } - - override def dataType: DataType = left.dataType - - private def elementType = dataType.asInstanceOf[ArrayType].elementType - private def cn1 = left.dataType.asInstanceOf[ArrayType].containsNull - private def cn2 = right.dataType.asInstanceOf[ArrayType].containsNull - - override def nullSafeEval(input1: Any, input2: Any): Any = { - val ary1 = input1.asInstanceOf[ArrayData] - val ary2 = input2.asInstanceOf[ArrayData] - - if (!cn1 && !cn2) { - elementType match { - case IntegerType => - // avoid boxing of primitive int array elements - var hs: OpenHashSet[Int] = null - val hs1 = new OpenHashSet[Int] - var i = 0 - while (i < ary1.numElements()) { - hs1.add(ary1.getInt(i)) - i += 1 - } - if (typeId == kindUnion) { - i = 0 - while (i < ary2.numElements()) { - hs1.add(ary2.getInt(i)) - i += 1 - } - hs = hs1 - } else if (typeId == kindIntersect) { - hs = new OpenHashSet[Int] - i = 0 - while (i < ary2.numElements()) { - val k = ary2.getInt(i) - if (hs1.contains(k)) { - hs.add(k) - } - i += 1 - } - } - UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) - case LongType => - // avoid boxing of primitive long array elements - var hs: OpenHashSet[Long] = null - val hs1 = new OpenHashSet[Long] - var i = 0 - while (i < ary1.numElements()) { - hs1.add(ary1.getLong(i)) - i += 1 - } - if (typeId == kindUnion) { - i = 0 - while (i < ary2.numElements()) { - hs1.add(ary2.getLong(i)) - i += 1 - } - hs = hs1 - } else if (typeId == kindIntersect) { - hs = new OpenHashSet[Long] - i = 0 - while (i < ary2.numElements()) { - val k = ary2.getLong(i) - if (hs1.contains(k)) { - hs.add(k) - } - i += 1 - } - } - UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) - case _ => - var hs: OpenHashSet[Any] = null - val hs1 = new OpenHashSet[Any] - var i = 0 - while (i < ary1.numElements()) { - hs1.add(ary1.get(i, elementType)) - i += 1 - } - if (typeId == kindUnion) { - i = 0 - while (i < ary2.numElements()) { - hs1.add(ary2.get(i, elementType)) - i += 1 - } - hs = hs1 - } else if (typeId == kindIntersect) { - hs = new OpenHashSet[Any] - i = 0 - while (i < ary2.numElements()) { - val k = ary2.get(i, elementType) - if (hs1.contains(k)) { - hs.add(k) - } - i += 1 - } - } - new GenericArrayData(hs.iterator.toArray) - } - } else { - if (typeId == kindUnion) { - ArraySetUtils.arrayUnion(ary1, ary2, elementType) - } else if (typeId == kindIntersect) { - ArraySetUtils.arrayIntersect(ary1, ary2, elementType) - } - } - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val hs = ctx.freshName("hs") - val hs1 = ctx.freshName("hs1") - val i = ctx.freshName("i") - val ArraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetUtils" - val genericArrayData = classOf[GenericArrayData].getName - val unsafeArrayData = classOf[UnsafeArrayData].getName - val openHashSet = classOf[OpenHashSet[_]].getName - val et = s"org.apache.spark.sql.types.DataTypes.$elementType" - val (postFix, classTag, getter, arrayBuilder, castType) = if (!cn1 && !cn2) { - val ptName = CodeGenerator.primitiveTypeName(elementType) - elementType match { - case ByteType | ShortType | IntegerType => - (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", - s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType)) - case LongType => - (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", - s"$unsafeArrayData.fromPrimitiveArray", "long") - case _ => - ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $et)", - s"new $genericArrayData", "Object") - } - } else { - ("", "", "", "", "") - } - - nullSafeCodeGen(ctx, ev, (ary1, ary2) => { - if (classTag != "") { - if (typeId == kindUnion) { - s""" - |$openHashSet $hs = new $openHashSet$postFix($classTag); - |for (int $i = 0; $i < $ary1.numElements(); $i++) { - | $hs.add$postFix($ary1.$getter); - |} - |for (int $i = 0; $i < $ary2.numElements(); $i++) { - | $hs.add$postFix($ary2.$getter); - |} - |${ev.value} = $arrayBuilder(($castType[]) $hs.iterator().toArray($classTag)); - """.stripMargin - } else if (typeId == kindIntersect) { - s""" - |$openHashSet $hs1 = new $openHashSet$postFix($classTag); - |for (int $i = 0; $i < $ary1.numElements(); $i++) { - | $hs1.add$postFix($ary1.$getter); - |} - |$openHashSet $hs = new $openHashSet$postFix($classTag); - |for (int $i = 0; $i < $ary2.numElements(); $i++) { - | if ($hs1.contains$postFix($ary2.$getter)) { - | $hs.add$postFix($ary2.$getter); - | } - |} - |${ev.value} = $arrayBuilder(($castType[]) $hs.iterator().toArray($classTag)); - """.stripMargin - } else { - throw new UnsupportedOperationException("typeId=" + typeId + "is not supported.") - } - } else { - val setOp = if (typeId == kindUnion) { - "Union" - } else if (typeId == kindIntersect) { - "Intersect" - } else { - throw new UnsupportedOperationException("typeId=" + typeId + "is not supported.") - } - s"${ev.value} = $ArraySetUtils$$.MODULE$$.array$setOp($ary1, $ary2, $et);" - } - }) - } -} - -object ArraySetUtils { - def arrayUnion(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { - new GenericArrayData(array1.toArray[AnyRef](et).union(array2.toArray[AnyRef](et)) - .distinct.asInstanceOf[Array[Any]]) - } - - def arrayIntersect(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { - new GenericArrayData(array1.toArray[AnyRef](et).intersect(array2.toArray[AnyRef](et)) - .distinct.asInstanceOf[Array[Any]]) - } -} - /** * Returns the array containing the given input value (left) count (right) times. */ @@ -3856,7 +3651,7 @@ case class ArrayDistinct(child: Expression) } /** - * Will become common base class for [[ArrayUnion]], ArrayIntersect, and [[ArrayExcept]]. + * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]]. */ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { override def checkInputDataTypes(): TypeCheckResult = { @@ -4175,7 +3970,7 @@ object ArrayUnion { @ExpressionDescription( usage = """ _FUNC_(array1, array2) - Returns an array of the elements in the intersection of array1 and - array2, without duplicates. The order of elements in the result is not determined. + array2, without duplicates. """, examples = """ Examples:Fun @@ -4183,8 +3978,261 @@ object ArrayUnion { array(1, 3) """, since = "2.4.0") -case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetUtils { - override def typeId: Int = kindIntersect +case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetLike { + var hsInt: OpenHashSet[Int] = _ + var hsLong: OpenHashSet[Long] = _ + + def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getInt(idx) + if (hsInt.contains(elem)) { + if (resultArray != null) { + resultArray.setInt(pos, elem) + } + hsInt.remove(elem) + true + } else { + false + } + } + + def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getLong(idx) + if (hsLong.contains(elem)) { + if (resultArray != null) { + resultArray.setLong(pos, elem) + } + hsLong.remove(elem) + true + } else { + false + } + } + + def evalIntLongPrimitiveType( + array1: ArrayData, + array2: ArrayData, + resultArray: ArrayData, + isLongType: Boolean): Int = { + // store elements into resultArray + var foundNullElement = false + var i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + foundNullElement = true + } else { + val assigned = if (!isLongType) { + hsInt.add(array1.getInt(i)) + } else { + hsLong.add(array1.getLong(i)) + } + } + i += 1 + } + var pos = 0 + i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + if (foundNullElement) { + if (resultArray != null) { + resultArray.setNullAt(pos) + } + pos += 1 + foundNullElement = false + } + } else { + val assigned = if (!isLongType) { + assignInt(array2, i, resultArray, pos) + } else { + assignLong(array2, i, resultArray, pos) + } + if (assigned) { + pos += 1 + } + } + i += 1 + } + pos + } + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + if (elementTypeSupportEquals) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + // calculate result array size + hsInt = new OpenHashSet[Int] + val elements = evalIntLongPrimitiveType(array1, array2, null, false) + // allocate result array + hsInt = new OpenHashSet[Int] + val resultArray = if (UnsafeArrayData.canUseGenericArrayData( + IntegerType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) + } else { + UnsafeArrayData.forPrimitiveArray( + Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) + } + // assign elements into the result array + evalIntLongPrimitiveType(array1, array2, resultArray, false) + resultArray + case LongType => + // avoid boxing of primitive long array elements + // calculate result array size + hsLong = new OpenHashSet[Long] + val elements = evalIntLongPrimitiveType(array1, array2, null, true) + // allocate result array + hsLong = new OpenHashSet[Long] + val resultArray = if (UnsafeArrayData.canUseGenericArrayData( + LongType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) + } else { + UnsafeArrayData.forPrimitiveArray( + Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) + } + // assign elements into the result array + evalIntLongPrimitiveType(array1, array2, resultArray, true) + resultArray + case _ => + val hs = new OpenHashSet[Any] + var foundNullElement = false + var i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + foundNullElement = true + } else { + val elem = array1.get(i, elementType) + hs.add(elem) + } + i += 1 + } + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + if (foundNullElement) { + arrayBuffer += null + foundNullElement = false + } + } else { + val elem = array2.get(i, elementType) + if (hs.contains(elem)) { + arrayBuffer += elem + hs.remove(elem) + } + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } + } else { + ArrayIntersect.intersectOrdering(array1, array2, elementType, ordering) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) = + if (elementTypeSupportEquals) { + elementType match { + case ByteType | ShortType | IntegerType | LongType => + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", + if (elementType == LongType) "Long" else "Int", + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + if (elementType == LongType) "(long)" else "(int)", + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case _ => + val genericArrayData = classOf[GenericArrayData].getName + val et = ctx.addReferenceObj("elementType", elementType) + ("", "Object", + s"get($i, $et)", s"update($pos, $value)", "Object", "", + s"${ev.value} = new $genericArrayData(new Object[$size]);") + } + } else { + ("", "", "", "", "", "", "") + } + + nullSafeCodeGen(ctx, ev, (array1, array2) => { + if (openHashElementType != "") { + // Here, we ensure elementTypeSupportEquals is true + val foundNullElement = ctx.freshName("foundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" + val hs = ctx.freshName("hs") + val arrayData = classOf[ArrayData].getName + val arrays = ctx.freshName("arrays") + val array = ctx.freshName("array") + val arrayDataIdx = ctx.freshName("arrayDataIdx") + s""" + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |boolean $foundNullElement = false; + |int $size = 0; + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | if ($array1.isNullAt($i)) { + | $foundNullElement = true; + | } else { + | $hs.add$postFix($array1.$getter); + | } + |} + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | if ($array2.isNullAt($i)) { + | if ($foundNullElement) { + | $size++; + | $foundNullElement = false; + | } + | } else { + | $javaTypeName $value = $array2.$getter; + | if ($hs.contains($castOp $value)) { + | $hs.remove$postFix($value); + | $size++; + | } + | } + |} + |$arrayBuilder + |$hs = new $openHashSet$postFix($classTag); + |$foundNullElement = false; + |int $pos = 0; + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | if ($array1.isNullAt($i)) { + | $foundNullElement = true; + | } else { + | $hs.add$postFix($array1.$getter); + | } + |} + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | if ($array2.isNullAt($i)) { + | if ($foundNullElement) { + | ${ev.value}.setNullAt($pos++); + | $foundNullElement = false; + | } + | } else { + | $javaTypeName $value = $array2.$getter; + | if ($hs.contains($castOp $value)) { + | $hs.remove$postFix($value); + | ${ev.value}.$setter; + | $pos++; + | } + | } + |} + """.stripMargin + } else { + val arrayIntersect = classOf[ArrayIntersect].getName + val et = ctx.addReferenceObj("elementTypeIntersect", elementType) + val order = ctx.addReferenceObj("orderingIntersect", ordering) + val method = "intersectOrdering" + s"${ev.value} = $arrayIntersect$$.MODULE$$.$method($array1, $array2, $et, $order);" + } + }) + } override def prettyName: String = "array_intersect" } @@ -4289,7 +4337,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike i += 1 } new GenericArrayData(arrayBuffer) - } + } } override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -4471,3 +4519,60 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike override def prettyName: String = "array_except" } + +object ArrayIntersect { + def intersectOrdering( + array1: ArrayData, + array2: ArrayData, + elementType: DataType, + ordering: Ordering[Any]): ArrayData = { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadySeenNull = false + var i = 0 + while (i < array1.numElements()) { + var found = false + val elem1 = array1.get(i, elementType) + if (array1.isNullAt(i)) { + if (!alreadySeenNull) { + var j = 0 + while (!found && j < array2.numElements()) { + if (array2.isNullAt(j)) { + found = true + } + j += 1 + } + // array2 is scaned only once for null element + alreadySeenNull = true + } + } else { + var j = 0 + while (!found && j < array2.numElements()) { + if (!array2.isNullAt(j)) { + val elem2 = array2.get(j, elementType) + if (ordering.equiv(elem1, elem2)) { + // check whether elem2 is already stored in arrayBuffer + var foundArrayBuffer = false + var k = 0 + while (!foundArrayBuffer && k < arrayBuffer.size) { + val va = arrayBuffer(k) + if (va != null && ordering.equiv(va, elem2)) { + foundArrayBuffer = true + } + k += 1 + } + if (!foundArrayBuffer) { + found = true + } + } + } + j += 1 + } + } + if (found) { + arrayBuffer += elem1 + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d7e81b67138fb..2945b055a647b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1622,41 +1622,102 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper test("Array Intersect") { val a00 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType, false)) val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) - val a02 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType)) - val a03 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType)) - val a04 = Literal.create(Seq(-5, 4, null, 2, -1), ArrayType(IntegerType)) - val a05 = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) + val a02 = Literal.create(Seq(1, 2, 1, 4), ArrayType(IntegerType, false)) + val a03 = Literal.create(Seq(4, 2, 4), ArrayType(IntegerType, false)) + val a04 = Literal.create(Seq(1, 2, null, 4, 5, null), ArrayType(IntegerType, true)) + val a05 = Literal.create(Seq(-5, 4, null, 2, -1, null), ArrayType(IntegerType, true)) + val a06 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, false)) + val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, containsNull = false)) + val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, containsNull = false)) + val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, containsNull = false)) + val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, containsNull = false)) val a10 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType, false)) val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) - val a12 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType)) - val a13 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType)) - val a14 = Literal.create(Seq(-5L, 4L, null, 2L, -1L), ArrayType(LongType)) - val a15 = Literal.create(Seq.empty[Long], ArrayType(LongType)) + val a12 = Literal.create(Seq(1L, 2L, 1L, 4L), ArrayType(LongType, false)) + val a13 = Literal.create(Seq(4L, 2L, 4L), ArrayType(LongType, false)) + val a14 = Literal.create(Seq(1L, 2L, null, 4L, 5L, null), ArrayType(LongType, true)) + val a15 = Literal.create(Seq(-5L, 4L, null, 2L, -1L, null), ArrayType(LongType, true)) + val a16 = Literal.create(Seq.empty[Long], ArrayType(LongType, false)) - val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType)) - val a21 = Literal.create(Seq("c", null, "a", "f"), ArrayType(StringType)) - val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType)) - val a23 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false)) - val a24 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, false)) + val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false)) + val a21 = Literal.create(Seq("c", "a"), ArrayType(StringType, false)) + val a22 = Literal.create(Seq("b", "a", "c", "a"), ArrayType(StringType, false)) + val a23 = Literal.create(Seq("c", null, "a", "f"), ArrayType(StringType, true)) + val a24 = Literal.create(Seq("b", null, "a", "g", null), ArrayType(StringType, true)) + val a25 = Literal.create(Seq.empty[String], ArrayType(StringType, false)) - val a30 = Literal.create(Seq(null, null), ArrayType(NullType)) + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val a31 = Literal.create(null, ArrayType(StringType)) - checkEvaluation(ArrayIntersect(a00, a01), UnsafeArrayData.fromPrimitiveArray(Array(4, 2))) - checkEvaluation(ArrayIntersect(a01, a02), Seq(4, 2)) - checkEvaluation(ArrayIntersect(a03, a04), Seq(2, null, 4)) - checkEvaluation(ArrayIntersect(a03, a05), Seq.empty) + checkEvaluation(ArrayIntersect(a00, a01), Seq(4, 2)) + checkEvaluation(ArrayIntersect(a01, a00), Seq(2, 4)) + checkEvaluation(ArrayIntersect(a02, a03), Seq(4, 2)) + checkEvaluation(ArrayIntersect(a03, a02), Seq(2, 4)) + checkEvaluation(ArrayIntersect(a04, a05), Seq(4, null, 2)) + checkEvaluation(ArrayIntersect(a02, a06), Seq.empty) + checkEvaluation(ArrayIntersect(a06, a04), Seq.empty) + checkEvaluation(ArrayIntersect(ab0, ab1), Seq[Byte](2)) + checkEvaluation(ArrayIntersect(as0, as1), Seq[Short](2)) + + checkEvaluation(ArrayIntersect(a10, a11), Seq(4L, 2L)) + checkEvaluation(ArrayIntersect(a11, a10), Seq(2L, 4L)) + checkEvaluation(ArrayIntersect(a12, a13), Seq(4L, 2L)) + checkEvaluation(ArrayIntersect(a13, a12), Seq(2L, 4L)) + checkEvaluation(ArrayIntersect(a14, a15), Seq(4L, null, 2L)) + checkEvaluation(ArrayIntersect(a12, a16), Seq.empty) + checkEvaluation(ArrayIntersect(a16, a14), Seq.empty) + + checkEvaluation(ArrayIntersect(a20, a21), Seq("c", "a")) + checkEvaluation(ArrayIntersect(a21, a20), Seq("a", "c")) + checkEvaluation(ArrayIntersect(a22, a21), Seq("c", "a")) + checkEvaluation(ArrayIntersect(a21, a22), Seq("a", "c")) + checkEvaluation(ArrayIntersect(a23, a24), Seq(null, "a")) + checkEvaluation(ArrayIntersect(a24, a23), Seq(null, "a")) + checkEvaluation(ArrayIntersect(a24, a25), Seq.empty) + checkEvaluation(ArrayIntersect(a25, a24), Seq.empty) - checkEvaluation( - ArrayIntersect(a10, a11), UnsafeArrayData.fromPrimitiveArray(Array(4L, 2L))) - checkEvaluation(ArrayIntersect(a11, a12), Seq(4L, 2L)) - checkEvaluation(ArrayIntersect(a13, a14), Seq(2L, null, 4L)) - checkEvaluation(ArrayIntersect(a13, a15), Seq.empty) + checkEvaluation(ArrayIntersect(a30, a30), Seq(null)) + checkEvaluation(ArrayIntersect(a20, a31), null) + checkEvaluation(ArrayIntersect(a31, a20), null) - checkEvaluation(ArrayIntersect(a20, a21), Seq("a", "c")) - checkEvaluation(ArrayIntersect(a21, a22), Seq(null, "a")) - checkEvaluation(ArrayIntersect(a23, a24), Seq("c", "a")) + val b0 = Literal.create( + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](3, 4)), + ArrayType(BinaryType)) + val b1 = Literal.create( + Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b2 = Literal.create( + Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4), null), + ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](3, 4), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) +/* + checkEvaluation(ArrayIntersect(b0, b1), Seq(Array[Byte](3, 4), Array[Byte](5, 6))) + checkEvaluation(ArrayIntersect(b1, b0), Seq(Array[Byte](5, 6), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b0, b2), Seq(Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b2, b0), Seq(Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b2, b3), Seq(Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b3, b2), Seq(Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b3, b4), Seq(Array[Byte](null, Array[Byte](3, 4)))) + checkEvaluation(ArrayIntersect(b4, b3), Seq(Array[Byte](Array[Byte](3, 4), null))) + checkEvaluation(ArrayIntersect(b4, b5), Seq.empty) + checkEvaluation(ArrayIntersect(b5, b4), Seq.empty) + checkEvaluation(ArrayIntersect(b4, arrayWithBinaryNull), Seq(null)) +*/ + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](3, 4), Seq[Int](2, 1), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayIntersect(aa0, aa1), Seq[Seq[Int]](Seq[Int](3, 4))) + checkEvaluation(ArrayIntersect(aa1, aa0), Seq[Seq[Int]](Seq[Int](3, 4))) - checkEvaluation(ArrayIntersect(a30, a30), Seq(null)) + assert(ArrayIntersect(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a00, a04).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayIntersect(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a20, a23).dataType.asInstanceOf[ArrayType].containsNull === true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index eb08635fc7689..5a565cb81ab8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1655,7 +1655,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array[Integer](-5, 4, null, 2, -1))) .toDF("a", "b") - val ans2 = Row(Seq(2, null, 4)) + val ans2 = Row(Seq(4, null, 2)) checkAnswer(df2.select(array_intersect($"a", $"b")), ans2) checkAnswer(df2.selectExpr("array_intersect(a, b)"), ans2) @@ -1664,7 +1664,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df3.select(array_intersect($"a", $"b")), ans3) checkAnswer(df3.selectExpr("array_intersect(a, b)"), ans3) - val ans4 = Row(Seq(2L, null, 4L)) + val df4 = Seq( + (Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array[java.lang.Long](-5L, 4L, null, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(4L, null, 2L)) checkAnswer(df4.select(array_intersect($"a", $"b")), ans4) checkAnswer(df4.selectExpr("array_intersect(a, b)"), ans4) @@ -1674,16 +1677,27 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df5.selectExpr("array_intersect(a, b)"), ans5) val df6 = Seq((null, null)).toDF("a", "b") - val ans6 = Row(null) - checkAnswer(df6.select(array_intersect($"a", $"b")), ans6) - checkAnswer(df6.selectExpr("array_intersect(a, b)"), ans6) + intercept[AnalysisException] { + df6.select(array_intersect($"a", $"b")) + } + intercept[AnalysisException] { + df6.selectExpr("array_intersect(a, b)") + } + + val df7 = Seq((Array(1), Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df7.select(array_intersect($"a", $"b")) + } + intercept[AnalysisException] { + df7.selectExpr("array_intersect(a, b)") + } - val df0 = Seq((Array(1), Array("a"))).toDF("a", "b") + val df8 = Seq((null, Array("a"))).toDF("a", "b") intercept[AnalysisException] { - df0.select(array_intersect($"a", $"b")) + df8.select(array_intersect($"a", $"b")) } intercept[AnalysisException] { - df0.selectExpr("array_intersect(a, b)") + df8.selectExpr("array_intersect(a, b)") } } From 4ec43b1f842c46c50b806b52106d091c6968d32b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 10 Jul 2018 12:39:43 +0100 Subject: [PATCH 05/21] add missing file --- python/pyspark/sql/functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 42fca146f3ba9..0e1c1162ee01a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2038,7 +2038,7 @@ def array_distinct(col): def array_intersect(col1, col2): """ Collection function: returns an array of the elements in the intersection of col1 and col2, - without duplicates. The order of elements in the result is not determined. + without duplicates. :param col1: name of column containing array :param col2: name of column containing array @@ -2046,7 +2046,7 @@ def array_intersect(col1, col2): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_intersect(df.c1, df.c2)).collect() - [Row(array_intersect(c1, c2)=[u'a', u'c'])] + [Row(array_intersect(c1, c2)=[u'c', u'a'])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_intersect(_to_java_column(col1), _to_java_column(col2))) From 2c303514f382eb51fcf6986a036de121ae811c38 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 12 Jul 2018 19:07:23 +0100 Subject: [PATCH 06/21] fix test failure --- .../org/apache/spark/util/collection/OpenHashSetSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 6e9fb9f6f71f4..285f2ae063924 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -150,7 +150,7 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.size === 3) assert(set.contains(10L)) assert(set.contains(50L)) - assert(!set.contains(999L)) + assert(set.contains(999L)) assert(!set.contains(10000L)) } From 74103cee5fb7e77a4a25545b2cada7590c06de40 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 12 Jul 2018 20:16:32 +0100 Subject: [PATCH 07/21] update --- .../catalyst/expressions/CollectionExpressionsSuite.scala | 7 +++---- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 2945b055a647b..4082a93c86422 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1695,19 +1695,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](3, 4), null), ArrayType(BinaryType)) val b5 = Literal.create(Seq.empty, ArrayType(BinaryType)) val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) -/* checkEvaluation(ArrayIntersect(b0, b1), Seq(Array[Byte](3, 4), Array[Byte](5, 6))) checkEvaluation(ArrayIntersect(b1, b0), Seq(Array[Byte](5, 6), Array[Byte](3, 4))) checkEvaluation(ArrayIntersect(b0, b2), Seq(Array[Byte](1, 2), Array[Byte](3, 4))) checkEvaluation(ArrayIntersect(b2, b0), Seq(Array[Byte](1, 2), Array[Byte](3, 4))) checkEvaluation(ArrayIntersect(b2, b3), Seq(Array[Byte](1, 2), Array[Byte](3, 4))) checkEvaluation(ArrayIntersect(b3, b2), Seq(Array[Byte](1, 2), Array[Byte](3, 4))) - checkEvaluation(ArrayIntersect(b3, b4), Seq(Array[Byte](null, Array[Byte](3, 4)))) - checkEvaluation(ArrayIntersect(b4, b3), Seq(Array[Byte](Array[Byte](3, 4), null))) + checkEvaluation(ArrayIntersect(b3, b4), Seq(null, Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b4, b3), Seq(Array[Byte](3, 4), null)) checkEvaluation(ArrayIntersect(b4, b5), Seq.empty) checkEvaluation(ArrayIntersect(b5, b4), Seq.empty) checkEvaluation(ArrayIntersect(b4, arrayWithBinaryNull), Seq(null)) -*/ + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2)), ArrayType(ArrayType(IntegerType))) val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](3, 4), Seq[Int](2, 1), Seq[Int](3, 4)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5a565cb81ab8c..7003892d1c97d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1173,8 +1173,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("concat function - arrays") { - val nseqi: Seq[Int] = null - val nseqs: Seq[String] = null + val nseqi : Seq[Int] = null + val nseqs : Seq[String] = null val df = Seq( (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) From 387d0ca428ae5d1c4c0624d06e0b6f35c91885de Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 12 Jul 2018 20:45:32 +0100 Subject: [PATCH 08/21] fix compilation failure --- .../expressions/collectionOperations.scala | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 171804e4e8a4b..a42c4ef1bcb15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3963,14 +3963,15 @@ object ArrayUnion { })) new GenericArrayData(arrayBuffer) } +} - /** - * Returns an array of the elements in the intersect of x and y, without duplicates - */ - @ExpressionDescription( - usage = """ - _FUNC_(array1, array2) - Returns an array of the elements in the intersection of array1 and - array2, without duplicates. +/** + * Returns an array of the elements in the intersect of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the intersection of array1 and + array2, without duplicates. """, examples = """ Examples:Fun @@ -4067,7 +4068,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL val elements = evalIntLongPrimitiveType(array1, array2, null, false) // allocate result array hsInt = new OpenHashSet[Int] - val resultArray = if (UnsafeArrayData.canUseGenericArrayData( + val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( IntegerType.defaultSize, elements)) { new GenericArrayData(new Array[Any](elements)) } else { @@ -4084,7 +4085,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL val elements = evalIntLongPrimitiveType(array1, array2, null, true) // allocate result array hsLong = new OpenHashSet[Long] - val resultArray = if (UnsafeArrayData.canUseGenericArrayData( + val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( LongType.defaultSize, elements)) { new GenericArrayData(new Array[Any](elements)) } else { From e374aa9c94322e9abb6f8baaf15c73dc28820f84 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 13 Jul 2018 10:10:39 +0100 Subject: [PATCH 09/21] fix test failures --- .../expressions/collectionOperations.scala | 24 +++++++++---------- .../CollectionExpressionsSuite.scala | 18 +++++++------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a42c4ef1bcb15..beb86eb985d22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4530,33 +4530,33 @@ object ArrayIntersect { val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] var alreadySeenNull = false var i = 0 - while (i < array1.numElements()) { + while (i < array2.numElements()) { var found = false - val elem1 = array1.get(i, elementType) - if (array1.isNullAt(i)) { + val elem2 = array2.get(i, elementType) + if (array2.isNullAt(i)) { if (!alreadySeenNull) { var j = 0 - while (!found && j < array2.numElements()) { - if (array2.isNullAt(j)) { + while (!found && j < array1.numElements()) { + if (array1.isNullAt(j)) { found = true } j += 1 } - // array2 is scaned only once for null element + // array1 is scaned only once for null element alreadySeenNull = true } } else { var j = 0 - while (!found && j < array2.numElements()) { - if (!array2.isNullAt(j)) { - val elem2 = array2.get(j, elementType) + while (!found && j < array1.numElements()) { + if (!array1.isNullAt(j)) { + val elem1 = array1.get(j, elementType) if (ordering.equiv(elem1, elem2)) { - // check whether elem2 is already stored in arrayBuffer + // check whether elem1 is already stored in arrayBuffer var foundArrayBuffer = false var k = 0 while (!foundArrayBuffer && k < arrayBuffer.size) { val va = arrayBuffer(k) - if (va != null && ordering.equiv(va, elem2)) { + if (va != null && ordering.equiv(va, elem1)) { foundArrayBuffer = true } k += 1 @@ -4570,7 +4570,7 @@ object ArrayIntersect { } } if (found) { - arrayBuffer += elem1 + arrayBuffer += elem2 } i += 1 } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 4082a93c86422..2d21dbdd14c57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1695,17 +1695,17 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](3, 4), null), ArrayType(BinaryType)) val b5 = Literal.create(Seq.empty, ArrayType(BinaryType)) val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) - checkEvaluation(ArrayIntersect(b0, b1), Seq(Array[Byte](3, 4), Array[Byte](5, 6))) - checkEvaluation(ArrayIntersect(b1, b0), Seq(Array[Byte](5, 6), Array[Byte](3, 4))) - checkEvaluation(ArrayIntersect(b0, b2), Seq(Array[Byte](1, 2), Array[Byte](3, 4))) - checkEvaluation(ArrayIntersect(b2, b0), Seq(Array[Byte](1, 2), Array[Byte](3, 4))) - checkEvaluation(ArrayIntersect(b2, b3), Seq(Array[Byte](1, 2), Array[Byte](3, 4))) - checkEvaluation(ArrayIntersect(b3, b2), Seq(Array[Byte](1, 2), Array[Byte](3, 4))) - checkEvaluation(ArrayIntersect(b3, b4), Seq(null, Array[Byte](3, 4))) - checkEvaluation(ArrayIntersect(b4, b3), Seq(Array[Byte](3, 4), null)) + checkEvaluation(ArrayIntersect(b0, b1), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](5, 6))) + checkEvaluation(ArrayIntersect(b1, b0), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b0, b2), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b2, b0), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b2, b3), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b3, b2), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b3, b4), Seq[Array[Byte]](null, Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b4, b3), Seq[Array[Byte]](Array[Byte](3, 4), null)) checkEvaluation(ArrayIntersect(b4, b5), Seq.empty) checkEvaluation(ArrayIntersect(b5, b4), Seq.empty) - checkEvaluation(ArrayIntersect(b4, arrayWithBinaryNull), Seq(null)) + checkEvaluation(ArrayIntersect(b4, arrayWithBinaryNull), Seq[Array[Byte]](null)) val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2)), ArrayType(ArrayType(IntegerType))) From e31f7e67acc4bd41adeff59f237bc7d526aeae07 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 14 Jul 2018 07:51:13 +0100 Subject: [PATCH 10/21] minor refactoring --- .../catalyst/expressions/collectionOperations.scala | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index beb86eb985d22..eb40f612edacd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4537,9 +4537,7 @@ object ArrayIntersect { if (!alreadySeenNull) { var j = 0 while (!found && j < array1.numElements()) { - if (array1.isNullAt(j)) { - found = true - } + found = array1.isNullAt(j) j += 1 } // array1 is scaned only once for null element @@ -4556,14 +4554,10 @@ object ArrayIntersect { var k = 0 while (!foundArrayBuffer && k < arrayBuffer.size) { val va = arrayBuffer(k) - if (va != null && ordering.equiv(va, elem1)) { - foundArrayBuffer = true - } + foundArrayBuffer = (va != null) && ordering.equiv(va, elem1) k += 1 } - if (!foundArrayBuffer) { - found = true - } + found = !foundArrayBuffer } } j += 1 From 9818ba91f51fffd2bc15054c18db76e1e182ca38 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 16 Jul 2018 08:33:46 +0100 Subject: [PATCH 11/21] update --- .../spark/util/collection/OpenHashSet.scala | 4 +- .../util/collection/OpenHashSetSuite.scala | 50 ++++++++++++++++--- .../expressions/collectionOperations.scala | 4 ++ .../CollectionExpressionsSuite.scala | 6 ++- 4 files changed, 54 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 739c862ed60e4..077bc000e2bb2 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -122,8 +122,8 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( * Remove an element from the set. If an element does not exists in the set, nothing is done. */ def remove(k: T): Unit = { - val pos = hashcode(hasher.hash(k)) & _mask - if (_bitset.get(pos)) { + val pos = getPos(k) + if (pos != INVALID_POS && _bitset.get(pos)) { _bitset.unset(pos) _size -= 1 } diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 285f2ae063924..73385f50dd313 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -74,25 +74,44 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.contains(999)) assert(!set.contains(10000)) + set.add(1132) // Cause hash contention with 999 + assert(set.size === 4) + assert(set.contains(10)) + assert(set.contains(50)) + assert(set.contains(999)) + assert(set.contains(1132)) + assert(!set.contains(10000)) + + set.remove(1132) + assert(set.size === 3) + assert(set.contains(10)) + assert(set.contains(50)) + assert(set.contains(999)) + assert(!set.contains(1132)) + assert(!set.contains(10000)) + set.remove(999) assert(set.size === 2) assert(set.contains(10)) assert(set.contains(50)) assert(!set.contains(999)) + assert(!set.contains(1132)) assert(!set.contains(10000)) - set.add(999) + set.add(1132) assert(set.size === 3) assert(set.contains(10)) assert(set.contains(50)) - assert(set.contains(999)) + assert(!set.contains(999)) + assert(set.contains(1132)) assert(!set.contains(10000)) set.remove(10000) assert(set.size === 3) assert(set.contains(10)) assert(set.contains(50)) - assert(set.contains(999)) + assert(!set.contains(999)) + assert(set.contains(1132)) assert(!set.contains(10000)) } @@ -132,25 +151,44 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.contains(999L)) assert(!set.contains(10000L)) + set.add(1132L) // Cause hash contention with 999L + assert(set.size === 4) + assert(set.contains(10L)) + assert(set.contains(50L)) + assert(set.contains(999L)) + assert(set.contains(1132L)) + assert(!set.contains(10000L)) + + set.remove(1132) + assert(set.size === 3) + assert(set.contains(10L)) + assert(set.contains(50L)) + assert(set.contains(999L)) + assert(!set.contains(1132L)) + assert(!set.contains(10000L)) + set.remove(999L) assert(set.size === 2) assert(set.contains(10L)) assert(set.contains(50L)) assert(!set.contains(999L)) + assert(!set.contains(1132L)) assert(!set.contains(10000L)) - set.add(999L) + set.add(1132L) assert(set.size === 3) assert(set.contains(10L)) assert(set.contains(50L)) - assert(set.contains(999L)) + assert(!set.contains(999L)) + assert(set.contains(1132L)) assert(!set.contains(10000L)) set.remove(10000L) assert(set.size === 3) assert(set.contains(10L)) assert(set.contains(50L)) - assert(set.contains(999L)) + assert(!set.contains(999L)) + assert(set.contains(1132L)) assert(!set.contains(10000L)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index eb40f612edacd..0af92aeea2394 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3980,6 +3980,10 @@ object ArrayUnion { """, since = "2.4.0") case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetLike { + override def dataType: DataType = ArrayType(elementType, + left.dataType.asInstanceOf[ArrayType].containsNull && + right.dataType.asInstanceOf[ArrayType].containsNull) + var hsInt: OpenHashSet[Int] = _ var hsLong: OpenHashSet[Long] = _ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 2d21dbdd14c57..e2ea3912462ba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1654,6 +1654,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayIntersect(a01, a00), Seq(2, 4)) checkEvaluation(ArrayIntersect(a02, a03), Seq(4, 2)) checkEvaluation(ArrayIntersect(a03, a02), Seq(2, 4)) + checkEvaluation(ArrayIntersect(a00, a04), Seq(1, 2, 4)) checkEvaluation(ArrayIntersect(a04, a05), Seq(4, null, 2)) checkEvaluation(ArrayIntersect(a02, a06), Seq.empty) checkEvaluation(ArrayIntersect(a06, a04), Seq.empty) @@ -1715,8 +1716,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayIntersect(aa1, aa0), Seq[Seq[Int]](Seq[Int](3, 4))) assert(ArrayIntersect(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) - assert(ArrayIntersect(a00, a04).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayIntersect(a00, a04).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a04, a05).dataType.asInstanceOf[ArrayType].containsNull === true) assert(ArrayIntersect(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) - assert(ArrayIntersect(a20, a23).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayIntersect(a23, a24).dataType.asInstanceOf[ArrayType].containsNull === true) } } From 352743f0e0fe18c88c37f67f535fff12351bae06 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 16 Jul 2018 19:42:18 +0100 Subject: [PATCH 12/21] address review comment --- .../spark/util/collection/OpenHashSet.scala | 48 ++++++++----- .../util/collection/OpenHashSetSuite.scala | 69 ++++++++++++++++--- 2 files changed, 93 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 077bc000e2bb2..851bab0a44799 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -89,9 +89,13 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( protected var _capacity = nextPowerOf2(initialCapacity) protected var _mask = _capacity - 1 protected var _size = 0 + protected var _occupied = 0 protected var _growThreshold = (loadFactor * _capacity).toInt + def g: Int = _growThreshold + def o: Int = _occupied protected var _bitset = new BitSet(_capacity) + protected var _bitsetDeleted: BitSet = null def getBitSet: BitSet = _bitset @@ -122,9 +126,13 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( * Remove an element from the set. If an element does not exists in the set, nothing is done. */ def remove(k: T): Unit = { + if (_bitsetDeleted == null) { + _bitsetDeleted = new BitSet(_capacity) + } val pos = getPos(k) - if (pos != INVALID_POS && _bitset.get(pos)) { + if (pos != INVALID_POS) { _bitset.unset(pos) + _bitsetDeleted.set(pos) _size -= 1 } } @@ -152,19 +160,24 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( var delta = 1 while (true) { if (!_bitset.get(pos)) { - // This is a new key. - _data(pos) = k - _bitset.set(pos) - _size += 1 - return pos | NONEXISTENCE_MASK + if (_bitsetDeleted == null || !_bitsetDeleted.get(pos)) { + // This is a new key. + _data(pos) = k + _bitset.set(pos) + if (_bitsetDeleted != null) { + _bitsetDeleted.unset(pos) + } + _size += 1 + _occupied += 1 + return pos | NONEXISTENCE_MASK + } } else if (_data(pos) == k) { // Found an existing key. return pos - } else { - // quadratic probing with values increase by 1, 2, 3, ... - pos = (pos + delta) & _mask - delta += 1 } + // quadratic probing with values increase by 1, 2, 3, ... + pos = (pos + delta) & _mask + delta += 1 } throw new RuntimeException("Should never reach here.") } @@ -178,7 +191,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( * to a new position (in the new data array). */ def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { - if (_size > _growThreshold) { + if (_occupied > _growThreshold) { rehash(k, allocateFunc, moveFunc) } } @@ -191,14 +204,15 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( var delta = 1 while (true) { if (!_bitset.get(pos)) { - return INVALID_POS + if (_bitsetDeleted == null || !_bitsetDeleted.get(pos)) { + return INVALID_POS + } } else if (k == _data(pos)) { return pos - } else { - // quadratic probing with values increase by 1, 2, 3, ... - pos = (pos + delta) & _mask - delta += 1 } + // quadratic probing with values increase by 1, 2, 3, ... + pos = (pos + delta) & _mask + delta += 1 } throw new RuntimeException("Should never reach here.") } @@ -219,6 +233,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( /** Return the value at the specified position. */ def getValueSafe(pos: Int): T = { assert(_bitset.get(pos)) + assert(_bitsetDeleted == null || !_bitsetDeleted.get(pos)) _data(pos) } @@ -274,6 +289,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( } _bitset = newBitset + _bitsetDeleted = null _data = newData _capacity = newCapacity _mask = newMask diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 73385f50dd313..bb36d334fd0d2 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -82,15 +82,15 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.contains(1132)) assert(!set.contains(10000)) - set.remove(1132) + set.remove(999) assert(set.size === 3) assert(set.contains(10)) assert(set.contains(50)) - assert(set.contains(999)) - assert(!set.contains(1132)) + assert(!set.contains(999)) + assert(set.contains(1132)) assert(!set.contains(10000)) - set.remove(999) + set.remove(1132) assert(set.size === 2) assert(set.contains(10)) assert(set.contains(50)) @@ -106,6 +106,22 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.contains(1132)) assert(!set.contains(10000)) + set.add(999) + assert(set.size === 4) + assert(set.contains(10)) + assert(set.contains(50)) + assert(set.contains(999)) + assert(set.contains(1132)) + assert(!set.contains(10000)) + + set.remove(999) + assert(set.size === 3) + assert(set.contains(10)) + assert(set.contains(50)) + assert(!set.contains(999)) + assert(set.contains(1132)) + assert(!set.contains(10000)) + set.remove(10000) assert(set.size === 3) assert(set.contains(10)) @@ -159,15 +175,15 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.contains(1132L)) assert(!set.contains(10000L)) - set.remove(1132) + set.remove(999L) assert(set.size === 3) assert(set.contains(10L)) assert(set.contains(50L)) - assert(set.contains(999L)) - assert(!set.contains(1132L)) + assert(!set.contains(999L)) + assert(set.contains(1132L)) assert(!set.contains(10000L)) - set.remove(999L) + set.remove(1132L) assert(set.size === 2) assert(set.contains(10L)) assert(set.contains(50L)) @@ -183,6 +199,22 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.contains(1132L)) assert(!set.contains(10000L)) + set.add(999L) + assert(set.size === 4) + assert(set.contains(10L)) + assert(set.contains(50L)) + assert(set.contains(999L)) + assert(set.contains(1132L)) + assert(!set.contains(10000L)) + + set.remove(999L) + assert(set.size === 3) + assert(set.contains(10L)) + assert(set.contains(50L)) + assert(!set.contains(999L)) + assert(set.contains(1132L)) + assert(!set.contains(10000L)) + set.remove(10000L) assert(set.size === 3) assert(set.contains(10L)) @@ -352,6 +384,27 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.capacity > 1000) } + test("growth with remove") { + val loadFactor = 0.7 + val set = new OpenHashSet[Int](8, loadFactor) + for (i <- 1 to 5) { + set.add(i) + assert(set.contains(i)) + set.remove(i) + assert(!set.contains(i)) + } + assert(set.size == 0) + assert(set.capacity == 8) + + // resize should occur + set.add(6) + assert(set.contains(6)) + set.remove(6) + assert(!set.contains(6)) + assert(set.size == 0) + assert(set.capacity > 8) + } + test("SPARK-18200 Support zero as an initial set size") { val set = new OpenHashSet[Long](0) assert(set.size === 0) From 89a828b702878cca8c96e5ef564519860467c3d1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 21 Jul 2018 20:56:40 +0100 Subject: [PATCH 13/21] avoid to use OpenHashSet.remove method --- .../spark/util/collection/OpenHashSet.scala | 59 ++----- .../util/collection/OpenHashSetSuite.scala | 154 ------------------ .../expressions/collectionOperations.scala | 90 +++++----- .../CollectionExpressionsSuite.scala | 2 +- .../spark/sql/DataFrameFunctionsSuite.scala | 10 ++ 5 files changed, 76 insertions(+), 239 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 851bab0a44799..8be050ecfba87 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -89,13 +89,9 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( protected var _capacity = nextPowerOf2(initialCapacity) protected var _mask = _capacity - 1 protected var _size = 0 - protected var _occupied = 0 protected var _growThreshold = (loadFactor * _capacity).toInt - def g: Int = _growThreshold - def o: Int = _occupied protected var _bitset = new BitSet(_capacity) - protected var _bitsetDeleted: BitSet = null def getBitSet: BitSet = _bitset @@ -122,21 +118,6 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( rehashIfNeeded(k, grow, move) } - /** - * Remove an element from the set. If an element does not exists in the set, nothing is done. - */ - def remove(k: T): Unit = { - if (_bitsetDeleted == null) { - _bitsetDeleted = new BitSet(_capacity) - } - val pos = getPos(k) - if (pos != INVALID_POS) { - _bitset.unset(pos) - _bitsetDeleted.set(pos) - _size -= 1 - } - } - def union(other: OpenHashSet[T]): OpenHashSet[T] = { val iterator = other.iterator while (iterator.hasNext) { @@ -160,24 +141,19 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( var delta = 1 while (true) { if (!_bitset.get(pos)) { - if (_bitsetDeleted == null || !_bitsetDeleted.get(pos)) { - // This is a new key. - _data(pos) = k - _bitset.set(pos) - if (_bitsetDeleted != null) { - _bitsetDeleted.unset(pos) - } - _size += 1 - _occupied += 1 - return pos | NONEXISTENCE_MASK - } + // This is a new key. + _data(pos) = k + _bitset.set(pos) + _size += 1 + return pos | NONEXISTENCE_MASK } else if (_data(pos) == k) { // Found an existing key. return pos + } else { + // quadratic probing with values increase by 1, 2, 3, ... + pos = (pos + delta) & _mask + delta += 1 } - // quadratic probing with values increase by 1, 2, 3, ... - pos = (pos + delta) & _mask - delta += 1 } throw new RuntimeException("Should never reach here.") } @@ -191,7 +167,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( * to a new position (in the new data array). */ def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { - if (_occupied > _growThreshold) { + if (_size > _growThreshold) { rehash(k, allocateFunc, moveFunc) } } @@ -204,15 +180,14 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( var delta = 1 while (true) { if (!_bitset.get(pos)) { - if (_bitsetDeleted == null || !_bitsetDeleted.get(pos)) { - return INVALID_POS - } + return INVALID_POS } else if (k == _data(pos)) { return pos + } else { + // quadratic probing with values increase by 1, 2, 3, ... + pos = (pos + delta) & _mask + delta += 1 } - // quadratic probing with values increase by 1, 2, 3, ... - pos = (pos + delta) & _mask - delta += 1 } throw new RuntimeException("Should never reach here.") } @@ -233,7 +208,6 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( /** Return the value at the specified position. */ def getValueSafe(pos: Int): T = { assert(_bitset.get(pos)) - assert(_bitsetDeleted == null || !_bitsetDeleted.get(pos)) _data(pos) } @@ -289,7 +263,6 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( } _bitset = newBitset - _bitsetDeleted = null _data = newData _capacity = newCapacity _mask = newMask @@ -303,7 +276,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( private def nextPowerOf2(n: Int): Int = { if (n == 0) { - 1 + 2 } else { val highBit = Integer.highestOneBit(n) if (highBit == n) n else highBit << 1 diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index bb36d334fd0d2..b887f937a9da9 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -73,62 +73,6 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.contains(50)) assert(set.contains(999)) assert(!set.contains(10000)) - - set.add(1132) // Cause hash contention with 999 - assert(set.size === 4) - assert(set.contains(10)) - assert(set.contains(50)) - assert(set.contains(999)) - assert(set.contains(1132)) - assert(!set.contains(10000)) - - set.remove(999) - assert(set.size === 3) - assert(set.contains(10)) - assert(set.contains(50)) - assert(!set.contains(999)) - assert(set.contains(1132)) - assert(!set.contains(10000)) - - set.remove(1132) - assert(set.size === 2) - assert(set.contains(10)) - assert(set.contains(50)) - assert(!set.contains(999)) - assert(!set.contains(1132)) - assert(!set.contains(10000)) - - set.add(1132) - assert(set.size === 3) - assert(set.contains(10)) - assert(set.contains(50)) - assert(!set.contains(999)) - assert(set.contains(1132)) - assert(!set.contains(10000)) - - set.add(999) - assert(set.size === 4) - assert(set.contains(10)) - assert(set.contains(50)) - assert(set.contains(999)) - assert(set.contains(1132)) - assert(!set.contains(10000)) - - set.remove(999) - assert(set.size === 3) - assert(set.contains(10)) - assert(set.contains(50)) - assert(!set.contains(999)) - assert(set.contains(1132)) - assert(!set.contains(10000)) - - set.remove(10000) - assert(set.size === 3) - assert(set.contains(10)) - assert(set.contains(50)) - assert(!set.contains(999)) - assert(set.contains(1132)) - assert(!set.contains(10000)) } test("primitive long") { @@ -166,62 +110,6 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.contains(50L)) assert(set.contains(999L)) assert(!set.contains(10000L)) - - set.add(1132L) // Cause hash contention with 999L - assert(set.size === 4) - assert(set.contains(10L)) - assert(set.contains(50L)) - assert(set.contains(999L)) - assert(set.contains(1132L)) - assert(!set.contains(10000L)) - - set.remove(999L) - assert(set.size === 3) - assert(set.contains(10L)) - assert(set.contains(50L)) - assert(!set.contains(999L)) - assert(set.contains(1132L)) - assert(!set.contains(10000L)) - - set.remove(1132L) - assert(set.size === 2) - assert(set.contains(10L)) - assert(set.contains(50L)) - assert(!set.contains(999L)) - assert(!set.contains(1132L)) - assert(!set.contains(10000L)) - - set.add(1132L) - assert(set.size === 3) - assert(set.contains(10L)) - assert(set.contains(50L)) - assert(!set.contains(999L)) - assert(set.contains(1132L)) - assert(!set.contains(10000L)) - - set.add(999L) - assert(set.size === 4) - assert(set.contains(10L)) - assert(set.contains(50L)) - assert(set.contains(999L)) - assert(set.contains(1132L)) - assert(!set.contains(10000L)) - - set.remove(999L) - assert(set.size === 3) - assert(set.contains(10L)) - assert(set.contains(50L)) - assert(!set.contains(999L)) - assert(set.contains(1132L)) - assert(!set.contains(10000L)) - - set.remove(10000L) - assert(set.size === 3) - assert(set.contains(10L)) - assert(set.contains(50L)) - assert(!set.contains(999L)) - assert(set.contains(1132L)) - assert(!set.contains(10000L)) } test("primitive float") { @@ -333,27 +221,6 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.contains(50.toString)) assert(set.contains(999.toString)) assert(!set.contains(10000.toString)) - - set.remove(999.toString) - assert(set.size === 2) - assert(set.contains(10.toString)) - assert(set.contains(50.toString)) - assert(!set.contains(999.toString)) - assert(!set.contains(10000.toString)) - - set.add(999.toString) - assert(set.size === 3) - assert(set.contains(10.toString)) - assert(set.contains(50.toString)) - assert(set.contains(999.toString)) - assert(!set.contains(10000.toString)) - - set.remove(10000.toString) - assert(set.size === 3) - assert(set.contains(10.toString)) - assert(set.contains(50.toString)) - assert(set.contains(999.toString)) - assert(!set.contains(10000.toString)) } test("non-primitive set growth") { @@ -384,27 +251,6 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.capacity > 1000) } - test("growth with remove") { - val loadFactor = 0.7 - val set = new OpenHashSet[Int](8, loadFactor) - for (i <- 1 to 5) { - set.add(i) - assert(set.contains(i)) - set.remove(i) - assert(!set.contains(i)) - } - assert(set.size == 0) - assert(set.capacity == 8) - - // resize should occur - set.add(6) - assert(set.contains(6)) - set.remove(6) - assert(!set.contains(6)) - assert(set.size == 0) - assert(set.capacity > 8) - } - test("SPARK-18200 Support zero as an initial set size") { val set = new OpenHashSet[Long](0) assert(set.size === 0) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0af92aeea2394..313e85958cd25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3981,19 +3981,21 @@ object ArrayUnion { since = "2.4.0") case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetLike { override def dataType: DataType = ArrayType(elementType, - left.dataType.asInstanceOf[ArrayType].containsNull && + left.dataType.asInstanceOf[ArrayType].containsNull || right.dataType.asInstanceOf[ArrayType].containsNull) var hsInt: OpenHashSet[Int] = _ + var hsResultInt: OpenHashSet[Int] = _ var hsLong: OpenHashSet[Long] = _ + var hsResultLong: OpenHashSet[Long] = _ def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { val elem = array.getInt(idx) - if (hsInt.contains(elem)) { + if (hsInt.contains(elem) && !hsResultInt.contains(elem)) { if (resultArray != null) { resultArray.setInt(pos, elem) } - hsInt.remove(elem) + hsResultInt.add(elem) true } else { false @@ -4002,11 +4004,11 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { val elem = array.getLong(idx) - if (hsLong.contains(elem)) { + if (hsLong.contains(elem) && !hsResultLong.contains(elem)) { if (resultArray != null) { resultArray.setLong(pos, elem) } - hsLong.remove(elem) + hsResultLong.add(elem) true } else { false @@ -4017,32 +4019,37 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL array1: ArrayData, array2: ArrayData, resultArray: ArrayData, - isLongType: Boolean): Int = { + initFoundNullElement: Boolean, + isLongType: Boolean): (Int, Boolean) = { // store elements into resultArray - var foundNullElement = false var i = 0 - while (i < array1.numElements()) { - if (array1.isNullAt(i)) { - foundNullElement = true - } else { - val assigned = if (!isLongType) { - hsInt.add(array1.getInt(i)) + var foundNullElement = initFoundNullElement + if (resultArray == null) { + // hsInt or hsLong is updated only once since it is not changed + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + foundNullElement = true } else { - hsLong.add(array1.getLong(i)) + val assigned = if (!isLongType) { + hsInt.add(array1.getInt(i)) + } else { + hsLong.add(array1.getLong(i)) + } } + i += 1 } - i += 1 } var pos = 0 i = 0 + var foundNullElementForResult = foundNullElement while (i < array2.numElements()) { if (array2.isNullAt(i)) { - if (foundNullElement) { + if (foundNullElementForResult) { if (resultArray != null) { resultArray.setNullAt(pos) } pos += 1 - foundNullElement = false + foundNullElementForResult = false } } else { val assigned = if (!isLongType) { @@ -4056,7 +4063,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL } i += 1 } - pos + (pos, foundNullElement) } override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -4069,9 +4076,11 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL // avoid boxing of primitive int array elements // calculate result array size hsInt = new OpenHashSet[Int] - val elements = evalIntLongPrimitiveType(array1, array2, null, false) + hsResultInt = new OpenHashSet[Int] + val (elements, foundNullElement) = + evalIntLongPrimitiveType(array1, array2, null, false, false) // allocate result array - hsInt = new OpenHashSet[Int] + hsResultInt = new OpenHashSet[Int] val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( IntegerType.defaultSize, elements)) { new GenericArrayData(new Array[Any](elements)) @@ -4080,15 +4089,17 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) } // assign elements into the result array - evalIntLongPrimitiveType(array1, array2, resultArray, false) + evalIntLongPrimitiveType(array1, array2, resultArray, foundNullElement, false) resultArray case LongType => // avoid boxing of primitive long array elements // calculate result array size hsLong = new OpenHashSet[Long] - val elements = evalIntLongPrimitiveType(array1, array2, null, true) + hsResultLong = new OpenHashSet[Long] + val (elements, foundNullElement) = + evalIntLongPrimitiveType(array1, array2, null, false, true) // allocate result array - hsLong = new OpenHashSet[Long] + hsResultLong = new OpenHashSet[Long] val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( LongType.defaultSize, elements)) { new GenericArrayData(new Array[Any](elements)) @@ -4097,10 +4108,11 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) } // assign elements into the result array - evalIntLongPrimitiveType(array1, array2, resultArray, true) + evalIntLongPrimitiveType(array1, array2, resultArray, foundNullElement, true) resultArray case _ => val hs = new OpenHashSet[Any] + val hsResult = new OpenHashSet[Any] var foundNullElement = false var i = 0 while (i < array1.numElements()) { @@ -4122,9 +4134,9 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL } } else { val elem = array2.get(i, elementType) - if (hs.contains(elem)) { + if (hs.contains(elem) && !hsResult.contains(elem)) { arrayBuffer += elem - hs.remove(elem) + hsResult.add(elem) } } i += 1 @@ -4170,15 +4182,18 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL if (openHashElementType != "") { // Here, we ensure elementTypeSupportEquals is true val foundNullElement = ctx.freshName("foundNullElement") + val foundNullElementForSize = ctx.freshName("foundNullElementForSize") val openHashSet = classOf[OpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" val hs = ctx.freshName("hs") + val hsResult = ctx.freshName("hsResult") val arrayData = classOf[ArrayData].getName val arrays = ctx.freshName("arrays") val array = ctx.freshName("array") val arrayDataIdx = ctx.freshName("arrayDataIdx") s""" |$openHashSet $hs = new $openHashSet$postFix($classTag); + |$openHashSet $hsResult = new $openHashSet$postFix($classTag); |boolean $foundNullElement = false; |int $size = 0; |for (int $i = 0; $i < $array1.numElements(); $i++) { @@ -4188,31 +4203,24 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL | $hs.add$postFix($array1.$getter); | } |} + |boolean $foundNullElementForSize = $foundNullElement; |for (int $i = 0; $i < $array2.numElements(); $i++) { | if ($array2.isNullAt($i)) { - | if ($foundNullElement) { + | if ($foundNullElementForSize) { | $size++; - | $foundNullElement = false; + | $foundNullElementForSize = false; | } | } else { | $javaTypeName $value = $array2.$getter; - | if ($hs.contains($castOp $value)) { - | $hs.remove$postFix($value); + | if ($hs.contains($castOp $value) && !$hsResult.contains($castOp $value)) { + | $hsResult.add$postFix($value); | $size++; | } | } |} |$arrayBuilder - |$hs = new $openHashSet$postFix($classTag); - |$foundNullElement = false; + |$hsResult = new $openHashSet$postFix($classTag); |int $pos = 0; - |for (int $i = 0; $i < $array1.numElements(); $i++) { - | if ($array1.isNullAt($i)) { - | $foundNullElement = true; - | } else { - | $hs.add$postFix($array1.$getter); - | } - |} |for (int $i = 0; $i < $array2.numElements(); $i++) { | if ($array2.isNullAt($i)) { | if ($foundNullElement) { @@ -4221,8 +4229,8 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL | } | } else { | $javaTypeName $value = $array2.$getter; - | if ($hs.contains($castOp $value)) { - | $hs.remove$postFix($value); + | if ($hs.contains($castOp $value) && !$hsResult.contains($castOp $value)) { + | $hsResult.add$postFix($value); | ${ev.value}.$setter; | $pos++; | } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index e2ea3912462ba..a4048db73f1d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1716,7 +1716,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayIntersect(aa1, aa0), Seq[Seq[Int]](Seq[Int](3, 4))) assert(ArrayIntersect(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) - assert(ArrayIntersect(a00, a04).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a00, a04).dataType.asInstanceOf[ArrayType].containsNull === true) assert(ArrayIntersect(a04, a05).dataType.asInstanceOf[ArrayType].containsNull === true) assert(ArrayIntersect(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayIntersect(a23, a24).dataType.asInstanceOf[ArrayType].containsNull === true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7003892d1c97d..3a384264b36c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1699,6 +1699,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { intercept[AnalysisException] { df8.selectExpr("array_intersect(a, b)") } + + val df9 = Seq( + (Array[Integer](1, 2), Array[Integer](2)), + (Array[Integer](1, 2), Array[Integer](1, null)), + (Array[Integer](1, null, 3), Array[Integer](1, 2)), + (Array[Integer](1, null), Array[Integer](2, null)) + ).toDF("a", "b") + val result9 = df9.select(array_intersect($"a", $"b")) + val expectedType9 = ArrayType(IntegerType, containsNull = true) + assert(result9.first.schema(0).dataType === expectedType9) } test("transform function - array for primitive type not containing null") { From 90fecd5c8d3c8d38cfd689173a32c5b61596230d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 22 Jul 2018 02:48:08 +0100 Subject: [PATCH 14/21] update dataType setting --- .../catalyst/expressions/collectionOperations.scala | 2 +- .../expressions/CollectionExpressionsSuite.scala | 2 +- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 10 ---------- 3 files changed, 2 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 313e85958cd25..fddfd13ea836b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3981,7 +3981,7 @@ object ArrayUnion { since = "2.4.0") case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetLike { override def dataType: DataType = ArrayType(elementType, - left.dataType.asInstanceOf[ArrayType].containsNull || + left.dataType.asInstanceOf[ArrayType].containsNull && right.dataType.asInstanceOf[ArrayType].containsNull) var hsInt: OpenHashSet[Int] = _ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a4048db73f1d6..e2ea3912462ba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1716,7 +1716,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayIntersect(aa1, aa0), Seq[Seq[Int]](Seq[Int](3, 4))) assert(ArrayIntersect(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) - assert(ArrayIntersect(a00, a04).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayIntersect(a00, a04).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayIntersect(a04, a05).dataType.asInstanceOf[ArrayType].containsNull === true) assert(ArrayIntersect(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayIntersect(a23, a24).dataType.asInstanceOf[ArrayType].containsNull === true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3a384264b36c8..7003892d1c97d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1699,16 +1699,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { intercept[AnalysisException] { df8.selectExpr("array_intersect(a, b)") } - - val df9 = Seq( - (Array[Integer](1, 2), Array[Integer](2)), - (Array[Integer](1, 2), Array[Integer](1, null)), - (Array[Integer](1, null, 3), Array[Integer](1, 2)), - (Array[Integer](1, null), Array[Integer](2, null)) - ).toDF("a", "b") - val result9 = df9.select(array_intersect($"a", $"b")) - val expectedType9 = ArrayType(IntegerType, containsNull = true) - assert(result9.first.schema(0).dataType === expectedType9) } test("transform function - array for primitive type not containing null") { From 663358ee2b6dec019dbb8e895a55761bb4a24f35 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 27 Jul 2018 20:13:13 +0100 Subject: [PATCH 15/21] update --- .../spark/util/collection/OpenHashSet.scala | 2 +- python/pyspark/sql/functions.py | 2 +- .../expressions/collectionOperations.scala | 302 +++++++----------- .../CollectionExpressionsSuite.scala | 57 ++-- .../spark/sql/DataFrameFunctionsSuite.scala | 10 +- 5 files changed, 164 insertions(+), 209 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 8be050ecfba87..8883e17bf3164 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -276,7 +276,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( private def nextPowerOf2(n: Int): Int = { if (n == 0) { - 2 + 1 } else { val highBit = Integer.highestOneBit(n) if (highBit == n) n else highBit << 1 diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0e1c1162ee01a..eaecf284b51f1 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2046,7 +2046,7 @@ def array_intersect(col1, col2): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_intersect(df.c1, df.c2)).collect() - [Row(array_intersect(c1, c2)=[u'c', u'a'])] + [Row(array_intersect(c1, c2)=[u'a', u'c'])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_intersect(_to_java_column(col1), _to_java_column(col2))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index fddfd13ea836b..b455ccf26c4aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3984,198 +3984,143 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL left.dataType.asInstanceOf[ArrayType].containsNull && right.dataType.asInstanceOf[ArrayType].containsNull) - var hsInt: OpenHashSet[Int] = _ - var hsResultInt: OpenHashSet[Int] = _ - var hsLong: OpenHashSet[Long] = _ - var hsResultLong: OpenHashSet[Long] = _ - - def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { - val elem = array.getInt(idx) - if (hsInt.contains(elem) && !hsResultInt.contains(elem)) { - if (resultArray != null) { - resultArray.setInt(pos, elem) - } - hsResultInt.add(elem) - true - } else { - false - } - } - - def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { - val elem = array.getLong(idx) - if (hsLong.contains(elem) && !hsResultLong.contains(elem)) { - if (resultArray != null) { - resultArray.setLong(pos, elem) - } - hsResultLong.add(elem) - true - } else { - false - } - } - - def evalIntLongPrimitiveType( - array1: ArrayData, - array2: ArrayData, - resultArray: ArrayData, - initFoundNullElement: Boolean, - isLongType: Boolean): (Int, Boolean) = { - // store elements into resultArray - var i = 0 - var foundNullElement = initFoundNullElement - if (resultArray == null) { - // hsInt or hsLong is updated only once since it is not changed - while (i < array1.numElements()) { - if (array1.isNullAt(i)) { - foundNullElement = true - } else { - val assigned = if (!isLongType) { - hsInt.add(array1.getInt(i)) + @transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = { + if (elementTypeSupportEquals) { + (array1, array2) => + val hs = new OpenHashSet[Any] + val hsResult = new OpenHashSet[Any] + var foundNullElement = false + var i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + foundNullElement = true } else { - hsLong.add(array1.getLong(i)) + val elem = array2.get(i, elementType) + hs.add(elem) } + i += 1 } - i += 1 - } - } - var pos = 0 - i = 0 - var foundNullElementForResult = foundNullElement - while (i < array2.numElements()) { - if (array2.isNullAt(i)) { - if (foundNullElementForResult) { - if (resultArray != null) { - resultArray.setNullAt(pos) + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (foundNullElement) { + arrayBuffer += null + foundNullElement = false + } + } else { + val elem = array1.get(i, elementType) + if (hs.contains(elem) && !hsResult.contains(elem)) { + arrayBuffer += elem + hsResult.add(elem) + } } - pos += 1 - foundNullElementForResult = false - } - } else { - val assigned = if (!isLongType) { - assignInt(array2, i, resultArray, pos) - } else { - assignLong(array2, i, resultArray, pos) + i += 1 } - if (assigned) { - pos += 1 + new GenericArrayData(arrayBuffer) + } else { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadySeenNull = false + var i = 0 + while (i < array1.numElements()) { + var found = false + val elem1 = array1.get(i, elementType) + if (array1.isNullAt(i)) { + if (!alreadySeenNull) { + var j = 0 + while (!found && j < array2.numElements()) { + found = array2.isNullAt(j) + j += 1 + } + // array2 is scanned only once for null element + alreadySeenNull = true + } + } else { + var j = 0 + while (!found && j < array2.numElements()) { + if (!array2.isNullAt(j)) { + val elem2 = array2.get(j, elementType) + if (ordering.equiv(elem1, elem2)) { + // check whether elem1 is already stored in arrayBuffer + var foundArrayBuffer = false + var k = 0 + while (!foundArrayBuffer && k < arrayBuffer.size) { + val va = arrayBuffer(k) + foundArrayBuffer = (va != null) && ordering.equiv(va, elem1) + k += 1 + } + found = !foundArrayBuffer + } + } + j += 1 + } + } + if (found) { + arrayBuffer += elem1 + } + i += 1 } - } - i += 1 + new GenericArrayData(arrayBuffer) } - (pos, foundNullElement) } override def nullSafeEval(input1: Any, input2: Any): Any = { val array1 = input1.asInstanceOf[ArrayData] val array2 = input2.asInstanceOf[ArrayData] - if (elementTypeSupportEquals) { - elementType match { - case IntegerType => - // avoid boxing of primitive int array elements - // calculate result array size - hsInt = new OpenHashSet[Int] - hsResultInt = new OpenHashSet[Int] - val (elements, foundNullElement) = - evalIntLongPrimitiveType(array1, array2, null, false, false) - // allocate result array - hsResultInt = new OpenHashSet[Int] - val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( - IntegerType.defaultSize, elements)) { - new GenericArrayData(new Array[Any](elements)) - } else { - UnsafeArrayData.forPrimitiveArray( - Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) - } - // assign elements into the result array - evalIntLongPrimitiveType(array1, array2, resultArray, foundNullElement, false) - resultArray - case LongType => - // avoid boxing of primitive long array elements - // calculate result array size - hsLong = new OpenHashSet[Long] - hsResultLong = new OpenHashSet[Long] - val (elements, foundNullElement) = - evalIntLongPrimitiveType(array1, array2, null, false, true) - // allocate result array - hsResultLong = new OpenHashSet[Long] - val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( - LongType.defaultSize, elements)) { - new GenericArrayData(new Array[Any](elements)) - } else { - UnsafeArrayData.forPrimitiveArray( - Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) - } - // assign elements into the result array - evalIntLongPrimitiveType(array1, array2, resultArray, foundNullElement, true) - resultArray - case _ => - val hs = new OpenHashSet[Any] - val hsResult = new OpenHashSet[Any] - var foundNullElement = false - var i = 0 - while (i < array1.numElements()) { - if (array1.isNullAt(i)) { - foundNullElement = true - } else { - val elem = array1.get(i, elementType) - hs.add(elem) - } - i += 1 - } - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - i = 0 - while (i < array2.numElements()) { - if (array2.isNullAt(i)) { - if (foundNullElement) { - arrayBuffer += null - foundNullElement = false - } - } else { - val elem = array2.get(i, elementType) - if (hs.contains(elem) && !hsResult.contains(elem)) { - arrayBuffer += elem - hsResult.add(elem) - } - } - i += 1 - } - new GenericArrayData(arrayBuffer) - } - } else { - ArrayIntersect.intersectOrdering(array1, array2, elementType, ordering) - } + evalIntersect(array1, array2) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val arrayData = classOf[ArrayData].getName val i = ctx.freshName("i") val pos = ctx.freshName("pos") val value = ctx.freshName("value") + val hsValue = ctx.freshName("hsValue") val size = ctx.freshName("size") - val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) = + val (postFix, openHashElementType, hsJavaTypeName, genHsValue, + getter, setter, javaTypeName, arrayBuilder) = if (elementTypeSupportEquals) { elementType match { - case ByteType | ShortType | IntegerType | LongType => + case BooleanType | ByteType | ShortType | IntegerType => val ptName = CodeGenerator.primitiveTypeName(elementType) val unsafeArray = ctx.freshName("unsafeArray") - (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", - if (elementType == LongType) "Long" else "Int", + ("$mcI$sp", "Int", "int", + if (elementType != BooleanType) { + s"(int) $value" + } else { + s"$value ? 1 : 0;" + }, s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), - if (elementType == LongType) "(long)" else "(int)", s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; - """.stripMargin) + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case LongType | FloatType | DoubleType => + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + val signature = elementType match { + case LongType => "$mcJ$sp" + case FloatType => "$mcF$sp" + case DoubleType => "$mcD$sp" + } + (signature, CodeGenerator.boxedType(elementType), + CodeGenerator.javaType(elementType), value, + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) case _ => val genericArrayData = classOf[GenericArrayData].getName val et = ctx.addReferenceObj("elementType", elementType) - ("", "Object", - s"get($i, $et)", s"update($pos, $value)", "Object", "", + ("", "Object", "Object", value, + s"get($i, $et)", s"update($pos, $value)", "Object", s"${ev.value} = new $genericArrayData(new Object[$size]);") } } else { - ("", "", "", "", "", "", "") + ("", "", "", "", "", "", "", "") } nullSafeCodeGen(ctx, ev, (array1, array2) => { @@ -4196,24 +4141,27 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL |$openHashSet $hsResult = new $openHashSet$postFix($classTag); |boolean $foundNullElement = false; |int $size = 0; - |for (int $i = 0; $i < $array1.numElements(); $i++) { - | if ($array1.isNullAt($i)) { + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | if ($array2.isNullAt($i)) { | $foundNullElement = true; | } else { - | $hs.add$postFix($array1.$getter); + | $javaTypeName $value = $array2.$getter; + | $hsJavaTypeName $hsValue = $genHsValue; + | $hs.add$postFix($hsValue); | } |} |boolean $foundNullElementForSize = $foundNullElement; - |for (int $i = 0; $i < $array2.numElements(); $i++) { - | if ($array2.isNullAt($i)) { + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | if ($array1.isNullAt($i)) { | if ($foundNullElementForSize) { | $size++; | $foundNullElementForSize = false; | } | } else { - | $javaTypeName $value = $array2.$getter; - | if ($hs.contains($castOp $value) && !$hsResult.contains($castOp $value)) { - | $hsResult.add$postFix($value); + | $javaTypeName $value = $array1.$getter; + | $hsJavaTypeName $hsValue = $genHsValue; + | if ($hs.contains($hsValue) && !$hsResult.contains($hsValue)) { + | $hsResult.add$postFix($hsValue); | $size++; | } | } @@ -4221,16 +4169,17 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL |$arrayBuilder |$hsResult = new $openHashSet$postFix($classTag); |int $pos = 0; - |for (int $i = 0; $i < $array2.numElements(); $i++) { - | if ($array2.isNullAt($i)) { + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | if ($array1.isNullAt($i)) { | if ($foundNullElement) { | ${ev.value}.setNullAt($pos++); | $foundNullElement = false; | } | } else { - | $javaTypeName $value = $array2.$getter; - | if ($hs.contains($castOp $value) && !$hsResult.contains($castOp $value)) { - | $hsResult.add$postFix($value); + | $javaTypeName $value = $array1.$getter; + | $hsJavaTypeName $hsValue = $genHsValue; + | if ($hs.contains($hsValue) && !$hsResult.contains($hsValue)) { + | $hsResult.add$postFix($hsValue); | ${ev.value}.$setter; | $pos++; | } @@ -4238,11 +4187,8 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL |} """.stripMargin } else { - val arrayIntersect = classOf[ArrayIntersect].getName - val et = ctx.addReferenceObj("elementTypeIntersect", elementType) - val order = ctx.addReferenceObj("orderingIntersect", ordering) - val method = "intersectOrdering" - s"${ev.value} = $arrayIntersect$$.MODULE$$.$method($array1, $array2, $et, $order);" + val expr = ctx.addReferenceObj("arrayIntersectExpr", this) + s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);" } }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index e2ea3912462ba..4daa113869b5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1627,10 +1627,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a04 = Literal.create(Seq(1, 2, null, 4, 5, null), ArrayType(IntegerType, true)) val a05 = Literal.create(Seq(-5, 4, null, 2, -1, null), ArrayType(IntegerType, true)) val a06 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, false)) + val abl0 = Literal.create(Seq[Boolean](true, false, true), ArrayType(BooleanType, false)) + val abl1 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false)) val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, containsNull = false)) val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, containsNull = false)) val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, containsNull = false)) val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, containsNull = false)) + val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false)) + val af1 = Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), ArrayType(FloatType, false)) + val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false)) + val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), ArrayType(DoubleType, false)) val a10 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType, false)) val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) @@ -1643,37 +1649,40 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false)) val a21 = Literal.create(Seq("c", "a"), ArrayType(StringType, false)) val a22 = Literal.create(Seq("b", "a", "c", "a"), ArrayType(StringType, false)) - val a23 = Literal.create(Seq("c", null, "a", "f"), ArrayType(StringType, true)) + val a23 = Literal.create(Seq("c", "a", null, "f"), ArrayType(StringType, true)) val a24 = Literal.create(Seq("b", null, "a", "g", null), ArrayType(StringType, true)) val a25 = Literal.create(Seq.empty[String], ArrayType(StringType, false)) val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) val a31 = Literal.create(null, ArrayType(StringType)) - checkEvaluation(ArrayIntersect(a00, a01), Seq(4, 2)) - checkEvaluation(ArrayIntersect(a01, a00), Seq(2, 4)) - checkEvaluation(ArrayIntersect(a02, a03), Seq(4, 2)) - checkEvaluation(ArrayIntersect(a03, a02), Seq(2, 4)) + checkEvaluation(ArrayIntersect(a00, a01), Seq(2, 4)) + checkEvaluation(ArrayIntersect(a01, a00), Seq(4, 2)) + checkEvaluation(ArrayIntersect(a02, a03), Seq(2, 4)) + checkEvaluation(ArrayIntersect(a03, a02), Seq(4, 2)) checkEvaluation(ArrayIntersect(a00, a04), Seq(1, 2, 4)) - checkEvaluation(ArrayIntersect(a04, a05), Seq(4, null, 2)) + checkEvaluation(ArrayIntersect(a04, a05), Seq(2, null, 4)) checkEvaluation(ArrayIntersect(a02, a06), Seq.empty) checkEvaluation(ArrayIntersect(a06, a04), Seq.empty) + checkEvaluation(ArrayIntersect(abl0, abl1), Seq[Boolean](true)) checkEvaluation(ArrayIntersect(ab0, ab1), Seq[Byte](2)) checkEvaluation(ArrayIntersect(as0, as1), Seq[Short](2)) - - checkEvaluation(ArrayIntersect(a10, a11), Seq(4L, 2L)) - checkEvaluation(ArrayIntersect(a11, a10), Seq(2L, 4L)) - checkEvaluation(ArrayIntersect(a12, a13), Seq(4L, 2L)) - checkEvaluation(ArrayIntersect(a13, a12), Seq(2L, 4L)) - checkEvaluation(ArrayIntersect(a14, a15), Seq(4L, null, 2L)) + checkEvaluation(ArrayIntersect(af0, af1), Seq[Float](2.2F)) + checkEvaluation(ArrayIntersect(ad0, ad1), Seq[Double](2.2D)) + + checkEvaluation(ArrayIntersect(a10, a11), Seq(2L, 4L)) + checkEvaluation(ArrayIntersect(a11, a10), Seq(4L, 2L)) + checkEvaluation(ArrayIntersect(a12, a13), Seq(2L, 4L)) + checkEvaluation(ArrayIntersect(a13, a12), Seq(4L, 2L)) + checkEvaluation(ArrayIntersect(a14, a15), Seq(2L, null, 4L)) checkEvaluation(ArrayIntersect(a12, a16), Seq.empty) checkEvaluation(ArrayIntersect(a16, a14), Seq.empty) - checkEvaluation(ArrayIntersect(a20, a21), Seq("c", "a")) - checkEvaluation(ArrayIntersect(a21, a20), Seq("a", "c")) - checkEvaluation(ArrayIntersect(a22, a21), Seq("c", "a")) - checkEvaluation(ArrayIntersect(a21, a22), Seq("a", "c")) - checkEvaluation(ArrayIntersect(a23, a24), Seq(null, "a")) + checkEvaluation(ArrayIntersect(a20, a21), Seq("a", "c")) + checkEvaluation(ArrayIntersect(a21, a20), Seq("c", "a")) + checkEvaluation(ArrayIntersect(a22, a21), Seq("a", "c")) + checkEvaluation(ArrayIntersect(a21, a22), Seq("c", "a")) + checkEvaluation(ArrayIntersect(a23, a24), Seq("a", null)) checkEvaluation(ArrayIntersect(a24, a23), Seq(null, "a")) checkEvaluation(ArrayIntersect(a24, a25), Seq.empty) checkEvaluation(ArrayIntersect(a25, a24), Seq.empty) @@ -1689,21 +1698,21 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), Array[Byte](5, 6)), ArrayType(BinaryType)) val b2 = Literal.create( - Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4), Array[Byte](1, 2)), + Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2), Array[Byte](1, 2)), ArrayType(BinaryType)) val b3 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4), null), ArrayType(BinaryType)) val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](3, 4), null), ArrayType(BinaryType)) val b5 = Literal.create(Seq.empty, ArrayType(BinaryType)) val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) - checkEvaluation(ArrayIntersect(b0, b1), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](5, 6))) - checkEvaluation(ArrayIntersect(b1, b0), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b0, b1), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b1, b0), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](5, 6))) checkEvaluation(ArrayIntersect(b0, b2), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) - checkEvaluation(ArrayIntersect(b2, b0), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) - checkEvaluation(ArrayIntersect(b2, b3), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b2, b0), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2))) + checkEvaluation(ArrayIntersect(b2, b3), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2))) checkEvaluation(ArrayIntersect(b3, b2), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) - checkEvaluation(ArrayIntersect(b3, b4), Seq[Array[Byte]](null, Array[Byte](3, 4))) - checkEvaluation(ArrayIntersect(b4, b3), Seq[Array[Byte]](Array[Byte](3, 4), null)) + checkEvaluation(ArrayIntersect(b3, b4), Seq[Array[Byte]](Array[Byte](3, 4), null)) + checkEvaluation(ArrayIntersect(b4, b3), Seq[Array[Byte]](null, Array[Byte](3, 4))) checkEvaluation(ArrayIntersect(b4, b5), Seq.empty) checkEvaluation(ArrayIntersect(b5, b4), Seq.empty) checkEvaluation(ArrayIntersect(b4, arrayWithBinaryNull), Seq[Array[Byte]](null)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7003892d1c97d..011d2022035dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1649,29 +1649,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array_intersect functions") { val df1 = Seq((Array(1, 2, 4), Array(4, 2))).toDF("a", "b") - val ans1 = Row(Seq(4, 2)) + val ans1 = Row(Seq(2, 4)) checkAnswer(df1.select(array_intersect($"a", $"b")), ans1) checkAnswer(df1.selectExpr("array_intersect(a, b)"), ans1) val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array[Integer](-5, 4, null, 2, -1))) .toDF("a", "b") - val ans2 = Row(Seq(4, null, 2)) + val ans2 = Row(Seq(2, null, 4)) checkAnswer(df2.select(array_intersect($"a", $"b")), ans2) checkAnswer(df2.selectExpr("array_intersect(a, b)"), ans2) val df3 = Seq((Array(1L, 2L, 4L), Array(4L, 2L))).toDF("a", "b") - val ans3 = Row(Seq(4L, 2L)) + val ans3 = Row(Seq(2L, 4L)) checkAnswer(df3.select(array_intersect($"a", $"b")), ans3) checkAnswer(df3.selectExpr("array_intersect(a, b)"), ans3) val df4 = Seq( (Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array[java.lang.Long](-5L, 4L, null, 2L, -1L))) .toDF("a", "b") - val ans4 = Row(Seq(4L, null, 2L)) + val ans4 = Row(Seq(2L, null, 4L)) checkAnswer(df4.select(array_intersect($"a", $"b")), ans4) checkAnswer(df4.selectExpr("array_intersect(a, b)"), ans4) - val df5 = Seq((Array("c", null, "a", "f"), Array("b", null, "a", "g"))).toDF("a", "b") + val df5 = Seq((Array("c", null, "a", "f"), Array("b", "a", null, "g"))).toDF("a", "b") val ans5 = Row(Seq(null, "a")) checkAnswer(df5.select(array_intersect($"a", $"b")), ans5) checkAnswer(df5.selectExpr("array_intersect(a, b)"), ans5) From 3ef99b8f311c9e3bd18e0d8ccb2750c8e79fac85 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 29 Jul 2018 06:39:15 +0100 Subject: [PATCH 16/21] optimize nullchecks in generated code --- .../expressions/collectionOperations.scala | 56 ++++++++++++++----- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b455ccf26c4aa..a340c88de0c89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4136,15 +4136,49 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL val arrays = ctx.freshName("arrays") val array = ctx.freshName("array") val arrayDataIdx = ctx.freshName("arrayDataIdx") + + val array2NullCheck = if (right.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $foundNullElement = true; + |} else + """.stripMargin + } else { + "" + } + val array1NullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($foundNullElementForSize) { + | $size++; + | $foundNullElementForSize = false; + | } + |} else + """.stripMargin + } else { + "" + } + val array1NullAssignment = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($foundNullElement) { + | ${ev.value}.setNullAt($pos++); + | $foundNullElement = false; + | } + |} else + """.stripMargin + } else { + "" + } + s""" |$openHashSet $hs = new $openHashSet$postFix($classTag); |$openHashSet $hsResult = new $openHashSet$postFix($classTag); |boolean $foundNullElement = false; |int $size = 0; |for (int $i = 0; $i < $array2.numElements(); $i++) { - | if ($array2.isNullAt($i)) { - | $foundNullElement = true; - | } else { + | $array2NullCheck + | { | $javaTypeName $value = $array2.$getter; | $hsJavaTypeName $hsValue = $genHsValue; | $hs.add$postFix($hsValue); @@ -4152,12 +4186,8 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL |} |boolean $foundNullElementForSize = $foundNullElement; |for (int $i = 0; $i < $array1.numElements(); $i++) { - | if ($array1.isNullAt($i)) { - | if ($foundNullElementForSize) { - | $size++; - | $foundNullElementForSize = false; - | } - | } else { + | $array1NullCheck + | { | $javaTypeName $value = $array1.$getter; | $hsJavaTypeName $hsValue = $genHsValue; | if ($hs.contains($hsValue) && !$hsResult.contains($hsValue)) { @@ -4170,12 +4200,8 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL |$hsResult = new $openHashSet$postFix($classTag); |int $pos = 0; |for (int $i = 0; $i < $array1.numElements(); $i++) { - | if ($array1.isNullAt($i)) { - | if ($foundNullElement) { - | ${ev.value}.setNullAt($pos++); - | $foundNullElement = false; - | } - | } else { + | $array1NullAssignment + | { | $javaTypeName $value = $array1.$getter; | $hsJavaTypeName $hsValue = $genHsValue; | if ($hs.contains($hsValue) && !$hsResult.contains($hsValue)) { From 3462b98c66eb9c59b5faf292cd914d6d6ec2f809 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 31 Jul 2018 19:11:48 +0100 Subject: [PATCH 17/21] update --- .../expressions/collectionOperations.scala | 211 ++++++++++-------- 1 file changed, 123 insertions(+), 88 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a340c88de0c89..a960857aefa7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4079,27 +4079,21 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL val value = ctx.freshName("value") val hsValue = ctx.freshName("hsValue") val size = ctx.freshName("size") - val (postFix, openHashElementType, hsJavaTypeName, genHsValue, - getter, setter, javaTypeName, arrayBuilder) = - if (elementTypeSupportEquals) { + if (elementTypeSupportEquals) { + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + val (postFix, openHashElementType, hsJavaTypeName, genHsValue, + getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) = elementType match { - case BooleanType | ByteType | ShortType | IntegerType => - val ptName = CodeGenerator.primitiveTypeName(elementType) - val unsafeArray = ctx.freshName("unsafeArray") - ("$mcI$sp", "Int", "int", - if (elementType != BooleanType) { - s"(int) $value" - } else { - s"$value ? 1 : 0;" - }, - s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + case ByteType | ShortType | IntegerType => + ("$mcI$sp", "Int", "int", s"(int) $value", + s"get$ptName($i)", s"set$ptName($pos, $value)", + CodeGenerator.javaType(elementType), ptName, s""" |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} |${ev.value} = $unsafeArray; """.stripMargin) case LongType | FloatType | DoubleType => - val ptName = CodeGenerator.primitiveTypeName(elementType) - val unsafeArray = ctx.freshName("unsafeArray") val signature = elementType match { case LongType => "$mcJ$sp" case FloatType => "$mcF$sp" @@ -4107,7 +4101,8 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL } (signature, CodeGenerator.boxedType(elementType), CodeGenerator.javaType(elementType), value, - s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + s"get$ptName($i)", s"set$ptName($pos, $value)", + CodeGenerator.javaType(elementType), ptName, s""" |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} |${ev.value} = $unsafeArray; @@ -4116,107 +4111,147 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL val genericArrayData = classOf[GenericArrayData].getName val et = ctx.addReferenceObj("elementType", elementType) ("", "Object", "Object", value, - s"get($i, $et)", s"update($pos, $value)", "Object", + s"get($i, $et)", s"update($pos, $value)", "Object", "Ref", s"${ev.value} = new $genericArrayData(new Object[$size]);") } - } else { - ("", "", "", "", "", "", "", "") - } - nullSafeCodeGen(ctx, ev, (array1, array2) => { - if (openHashElementType != "") { - // Here, we ensure elementTypeSupportEquals is true + nullSafeCodeGen(ctx, ev, (array1, array2) => { val foundNullElement = ctx.freshName("foundNullElement") - val foundNullElementForSize = ctx.freshName("foundNullElementForSize") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") + val array = ctx.freshName("array") val openHashSet = classOf[OpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" val hs = ctx.freshName("hs") val hsResult = ctx.freshName("hsResult") - val arrayData = classOf[ArrayData].getName - val arrays = ctx.freshName("arrays") - val array = ctx.freshName("array") - val arrayDataIdx = ctx.freshName("arrayDataIdx") + val genericArrayData = classOf[GenericArrayData].getName + val arrayBuilder = "scala.collection.mutable.ArrayBuilder" + val arrayBuilderClass = s"$arrayBuilder$$of$primitiveTypeName" + val arrayBuilderClassTag = if (primitiveTypeName != "Ref") { + s"scala.reflect.ClassTag$$.MODULE$$.$primitiveTypeName()" + } else { + s"scala.reflect.ClassTag$$.MODULE$$.AnyRef()" + } - val array2NullCheck = if (right.dataType.asInstanceOf[ArrayType].containsNull) { + def withArray2NullCheck(body: String) = + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $foundNullElement = true; + |} else { + | $body + |} + """.stripMargin + } else { + body + } + val array2Body = s""" - |if ($array2.isNullAt($i)) { - | $foundNullElement = true; - |} else + |$javaTypeName $value = $array2.$getter; + |$hsJavaTypeName $hsValue = $genHsValue; + |$hs.add$postFix($hsValue); """.stripMargin - } else { - "" - } - val array1NullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + + def withArray1NullAssignment(body: String) = + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = false; + | $size++; + | } + |} else { + | $body + |} + """.stripMargin + } else { + body + } + val array1Body = s""" - |if ($array1.isNullAt($i)) { - | if ($foundNullElementForSize) { - | $size++; - | $foundNullElementForSize = false; + |$javaTypeName $value = $array1.$getter; + |$hsJavaTypeName $hsValue = $genHsValue; + |if ($hs.contains($hsValue) && !$hsResult.contains($hsValue)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; | } - |} else + | $hsResult.add$postFix($hsValue); + | $builder.$$plus$$eq($value); + |} """.stripMargin - } else { - "" - } - val array1NullAssignment = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + + val nonNullArrayDataBuild = { + val build = if (postFix != "") { + val defaultSize = elementType.defaultSize + s""" + |if (!UnsafeArrayData.shouldUseGenericArrayData($defaultSize, $size)) { + | ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result()); + |} else { + | ${ev.value} = new $genericArrayData($builder.result()); + |} + """.stripMargin + } else { + s"${ev.value} = new $genericArrayData($builder.result());" + } s""" - |if ($array1.isNullAt($i)) { - | if ($foundNullElement) { - | ${ev.value}.setNullAt($pos++); - | $foundNullElement = false; - | } - |} else + |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try create array with " + $size + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for GenericArrayData." + + | " $prettyName failed."); + |} + |$build """.stripMargin - } else { - "" } + def buildResultArrayData(nonNullArrayDataBuild: String) = + if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($nullElementIndex < 0) { + | // result has no null element + | $nonNullArrayDataBuild + |} else { + | // result has null element + | $arrayDataBuilder + | $javaTypeName[] $array = $builder.result(); + | for (int $i = 0, $pos = 0; $pos < $size; $pos++) { + | if ($pos == $nullElementIndex) { + | ${ev.value}.setNullAt($pos); + | } else { + | $javaTypeName $value = $array[$i++]; + | ${ev.value}.$setter; + | } + | } + |} + """.stripMargin + } else { + nonNullArrayDataBuild + } + s""" |$openHashSet $hs = new $openHashSet$postFix($classTag); |$openHashSet $hsResult = new $openHashSet$postFix($classTag); |boolean $foundNullElement = false; - |int $size = 0; |for (int $i = 0; $i < $array2.numElements(); $i++) { - | $array2NullCheck - | { - | $javaTypeName $value = $array2.$getter; - | $hsJavaTypeName $hsValue = $genHsValue; - | $hs.add$postFix($hsValue); - | } + | ${withArray2NullCheck(array2Body)} |} - |boolean $foundNullElementForSize = $foundNullElement; - |for (int $i = 0; $i < $array1.numElements(); $i++) { - | $array1NullCheck - | { - | $javaTypeName $value = $array1.$getter; - | $hsJavaTypeName $hsValue = $genHsValue; - | if ($hs.contains($hsValue) && !$hsResult.contains($hsValue)) { - | $hsResult.add$postFix($hsValue); - | $size++; - | } - | } - |} - |$arrayBuilder - |$hsResult = new $openHashSet$postFix($classTag); - |int $pos = 0; + |$arrayBuilderClass $builder = + | ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag); + |int $nullElementIndex = -1; + |int $size = 0; |for (int $i = 0; $i < $array1.numElements(); $i++) { - | $array1NullAssignment - | { - | $javaTypeName $value = $array1.$getter; - | $hsJavaTypeName $hsValue = $genHsValue; - | if ($hs.contains($hsValue) && !$hsResult.contains($hsValue)) { - | $hsResult.add$postFix($hsValue); - | ${ev.value}.$setter; - | $pos++; - | } - | } + | ${withArray1NullAssignment(array1Body)} |} + |${buildResultArrayData(nonNullArrayDataBuild)} """.stripMargin - } else { + }) + } else { + nullSafeCodeGen(ctx, ev, (array1, array2) => { val expr = ctx.addReferenceObj("arrayIntersectExpr", this) s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);" - } - }) + }) + } } override def prettyName: String = "array_intersect" From cf742d4e87c56c96896b0e889bf1276728610159 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 2 Aug 2018 09:03:36 +0100 Subject: [PATCH 18/21] refactor methods that are used in Except and Intersect --- .../expressions/collectionOperations.scala | 247 ++++++++---------- 1 file changed, 114 insertions(+), 133 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a960857aefa7f..90ffbbb6d2632 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3672,6 +3672,113 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { case _: AtomicType => true case _ => false } + + protected def makeAccessors( + ctx : CodegenContext, + ev : ExprCode, + unsafeArray : String, + size : String, + value : String, + pos : String, + i : String) : (String, String, String, String, String, String, String, String, String) = { + val ptName = CodeGenerator.primitiveTypeName(elementType) + elementType match { + case ByteType | ShortType | IntegerType => + ("$mcI$sp", "Int", "int", s"(int) $value", + s"get$ptName($i)", s"set$ptName($pos, $value)", + CodeGenerator.javaType(elementType), ptName, + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case LongType | FloatType | DoubleType => + val signature = elementType match { + case LongType => "$mcJ$sp" + case FloatType => "$mcF$sp" + case DoubleType => "$mcD$sp" + } + (signature, CodeGenerator.boxedType(elementType), + CodeGenerator.javaType(elementType), value, + s"get$ptName($i)", s"set$ptName($pos, $value)", + CodeGenerator.javaType(elementType), ptName, + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case _ => + val genericArrayData = classOf[GenericArrayData].getName + val et = ctx.addReferenceObj("elementType", elementType) + ("", "Object", "Object", value, + s"get($i, $et)", s"update($pos, $value)", "Object", "Ref", + s"${ev.value} = new $genericArrayData(new Object[$size]);") + } + } + + protected def nonNullArrayDataBuild( + mayUseUnsafeArray: Boolean, + ev : ExprCode, + builder : String, + size : String) : String = { + val genericArrayData = classOf[GenericArrayData].getName + val build = if (mayUseUnsafeArray) { + val defaultSize = elementType.defaultSize + s""" + |if (!UnsafeArrayData.shouldUseGenericArrayData($defaultSize, $size)) { + | ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result()); + |} else { + | ${ev.value} = new $genericArrayData($builder.result()); + |} + """.stripMargin + } else { + s"${ev.value} = new $genericArrayData($builder.result());" + } + s""" + |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try create array with " + $size + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for GenericArrayData." + + | " $prettyName failed."); + |} + |$build + """.stripMargin + } + + protected def buildResultArrayData( + nonNullArrayDataBuild : String, + arrayDataBuilder : String, + ev : ExprCode, + builder : String, + setter : String, + javaTypeName : String, + array : String, + value : String, + nullElementIndex : String, + size : String, + i : String, + pos : String) : String = { + if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($nullElementIndex < 0) { + | // result has no null element + | $nonNullArrayDataBuild + |} else { + | // result has null element + | $arrayDataBuilder + | $javaTypeName[] $array = $builder.result(); + | for (int $i = 0, $pos = 0; $pos < $size; $pos++) { + | if ($pos == $nullElementIndex) { + | ${ev.value}.setNullAt($pos); + | } else { + | $javaTypeName $value = $array[$i++]; + | ${ev.value}.$setter; + | } + | } + |} + """.stripMargin + } else { + nonNullArrayDataBuild + } + } } object ArraySetLike { @@ -4080,40 +4187,10 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL val hsValue = ctx.freshName("hsValue") val size = ctx.freshName("size") if (elementTypeSupportEquals) { - val ptName = CodeGenerator.primitiveTypeName(elementType) val unsafeArray = ctx.freshName("unsafeArray") val (postFix, openHashElementType, hsJavaTypeName, genHsValue, - getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) = - elementType match { - case ByteType | ShortType | IntegerType => - ("$mcI$sp", "Int", "int", s"(int) $value", - s"get$ptName($i)", s"set$ptName($pos, $value)", - CodeGenerator.javaType(elementType), ptName, - s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; - """.stripMargin) - case LongType | FloatType | DoubleType => - val signature = elementType match { - case LongType => "$mcJ$sp" - case FloatType => "$mcF$sp" - case DoubleType => "$mcD$sp" - } - (signature, CodeGenerator.boxedType(elementType), - CodeGenerator.javaType(elementType), value, - s"get$ptName($i)", s"set$ptName($pos, $value)", - CodeGenerator.javaType(elementType), ptName, - s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; - """.stripMargin) - case _ => - val genericArrayData = classOf[GenericArrayData].getName - val et = ctx.addReferenceObj("elementType", elementType) - ("", "Object", "Object", value, - s"get($i, $et)", s"update($pos, $value)", "Object", "Ref", - s"${ev.value} = new $genericArrayData(new Object[$size]);") - } + getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) = + makeAccessors(ctx, ev, unsafeArray, size, value, pos, i) nullSafeCodeGen(ctx, ev, (array1, array2) => { val foundNullElement = ctx.freshName("foundNullElement") @@ -4124,7 +4201,6 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" val hs = ctx.freshName("hs") val hsResult = ctx.freshName("hsResult") - val genericArrayData = classOf[GenericArrayData].getName val arrayBuilder = "scala.collection.mutable.ArrayBuilder" val arrayBuilderClass = s"$arrayBuilder$$of$primitiveTypeName" val arrayBuilderClassTag = if (primitiveTypeName != "Ref") { @@ -4181,54 +4257,10 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL |} """.stripMargin - val nonNullArrayDataBuild = { - val build = if (postFix != "") { - val defaultSize = elementType.defaultSize - s""" - |if (!UnsafeArrayData.shouldUseGenericArrayData($defaultSize, $size)) { - | ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result()); - |} else { - | ${ev.value} = new $genericArrayData($builder.result()); - |} - """.stripMargin - } else { - s"${ev.value} = new $genericArrayData($builder.result());" - } - s""" - |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try create array with " + $size + - | " bytes of data due to exceeding the limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for GenericArrayData." + - | " $prettyName failed."); - |} - |$build - """.stripMargin - } - - def buildResultArrayData(nonNullArrayDataBuild: String) = - if (dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |if ($nullElementIndex < 0) { - | // result has no null element - | $nonNullArrayDataBuild - |} else { - | // result has null element - | $arrayDataBuilder - | $javaTypeName[] $array = $builder.result(); - | for (int $i = 0, $pos = 0; $pos < $size; $pos++) { - | if ($pos == $nullElementIndex) { - | ${ev.value}.setNullAt($pos); - | } else { - | $javaTypeName $value = $array[$i++]; - | ${ev.value}.$setter; - | } - | } - |} - """.stripMargin - } else { - nonNullArrayDataBuild - } - + val nonNullArrayData = nonNullArrayDataBuild(postFix != "", ev, builder, size) + val resultArrayData = buildResultArrayData( + nonNullArrayData, arrayDataBuilder, ev, builder, setter, + javaTypeName, array, value, nullElementIndex, size, i, pos) s""" |$openHashSet $hs = new $openHashSet$postFix($classTag); |$openHashSet $hsResult = new $openHashSet$postFix($classTag); @@ -4243,7 +4275,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL |for (int $i = 0; $i < $array1.numElements(); $i++) { | ${withArray1NullAssignment(array1Body)} |} - |${buildResultArrayData(nonNullArrayDataBuild)} + |$resultArrayData """.stripMargin }) } else { @@ -4539,54 +4571,3 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike override def prettyName: String = "array_except" } - -object ArrayIntersect { - def intersectOrdering( - array1: ArrayData, - array2: ArrayData, - elementType: DataType, - ordering: Ordering[Any]): ArrayData = { - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - var alreadySeenNull = false - var i = 0 - while (i < array2.numElements()) { - var found = false - val elem2 = array2.get(i, elementType) - if (array2.isNullAt(i)) { - if (!alreadySeenNull) { - var j = 0 - while (!found && j < array1.numElements()) { - found = array1.isNullAt(j) - j += 1 - } - // array1 is scaned only once for null element - alreadySeenNull = true - } - } else { - var j = 0 - while (!found && j < array1.numElements()) { - if (!array1.isNullAt(j)) { - val elem1 = array1.get(j, elementType) - if (ordering.equiv(elem1, elem2)) { - // check whether elem1 is already stored in arrayBuffer - var foundArrayBuffer = false - var k = 0 - while (!foundArrayBuffer && k < arrayBuffer.size) { - val va = arrayBuffer(k) - foundArrayBuffer = (va != null) && ordering.equiv(va, elem1) - k += 1 - } - found = !foundArrayBuffer - } - } - j += 1 - } - } - if (found) { - arrayBuffer += elem2 - } - i += 1 - } - new GenericArrayData(arrayBuffer) - } -} From 6fba1ee8c3525a6f34bf5580737d067a8f0d976d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 3 Aug 2018 21:29:46 +0100 Subject: [PATCH 19/21] update --- .../expressions/collectionOperations.scala | 352 +++++++----------- 1 file changed, 137 insertions(+), 215 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 90ffbbb6d2632..ccdd8aaa96aa1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3673,112 +3673,74 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { case _ => false } - protected def makeAccessors( - ctx : CodegenContext, - ev : ExprCode, - unsafeArray : String, - size : String, - value : String, - pos : String, - i : String) : (String, String, String, String, String, String, String, String, String) = { - val ptName = CodeGenerator.primitiveTypeName(elementType) + @transient protected lazy val canUseSpecializedHashSet = elementType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true + case _ => false + } + + protected def genGetValue(array: String, i: String): String = + CodeGenerator.getValue(array, elementType, i) + + @transient protected lazy val (hsPostFix, hsTypeName) = { + val ptName = CodeGenerator.primitiveTypeName (elementType) elementType match { - case ByteType | ShortType | IntegerType => - ("$mcI$sp", "Int", "int", s"(int) $value", - s"get$ptName($i)", s"set$ptName($pos, $value)", - CodeGenerator.javaType(elementType), ptName, - s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; - """.stripMargin) - case LongType | FloatType | DoubleType => - val signature = elementType match { - case LongType => "$mcJ$sp" - case FloatType => "$mcF$sp" - case DoubleType => "$mcD$sp" - } - (signature, CodeGenerator.boxedType(elementType), - CodeGenerator.javaType(elementType), value, - s"get$ptName($i)", s"set$ptName($pos, $value)", - CodeGenerator.javaType(elementType), ptName, - s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; - """.stripMargin) - case _ => - val genericArrayData = classOf[GenericArrayData].getName - val et = ctx.addReferenceObj("elementType", elementType) - ("", "Object", "Object", value, - s"get($i, $et)", s"update($pos, $value)", "Object", "Ref", - s"${ev.value} = new $genericArrayData(new Object[$size]);") + // we cast byte/short to int when writing to the hash set. + case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") + case LongType => ("$mcJ$sp", ptName) + case FloatType => ("$mcF$sp", ptName) + case DoubleType => ("$mcD$sp", ptName) } } - protected def nonNullArrayDataBuild( - mayUseUnsafeArray: Boolean, - ev : ExprCode, - builder : String, - size : String) : String = { - val genericArrayData = classOf[GenericArrayData].getName - val build = if (mayUseUnsafeArray) { - val defaultSize = elementType.defaultSize - s""" - |if (!UnsafeArrayData.shouldUseGenericArrayData($defaultSize, $size)) { - | ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result()); - |} else { - | ${ev.value} = new $genericArrayData($builder.result()); - |} - """.stripMargin - } else { - s"${ev.value} = new $genericArrayData($builder.result());" - } - s""" - |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try create array with " + $size + - | " bytes of data due to exceeding the limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for GenericArrayData." + - | " $prettyName failed."); - |} - |$build - """.stripMargin + // we cast byte/short to int when writing to the hash set. + @transient protected lazy val hsValueCast = elementType match { + case ByteType | ShortType => "(int) " + case _ => "" } - protected def buildResultArrayData( - nonNullArrayDataBuild : String, - arrayDataBuilder : String, - ev : ExprCode, - builder : String, - setter : String, - javaTypeName : String, - array : String, - value : String, - nullElementIndex : String, - size : String, - i : String, - pos : String) : String = { + // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will + // convert ArrayBuilder to ArrayData and setNull on the slot with null holder. + @transient protected lazy val nullValueHolder = elementType match { + case ByteType => "(byte) 0" + case ShortType => "(short) 0" + case _ => "0" + } + + protected def withResultArrayNullCheck( + body: String, + value: String, + nullElementIndex: String): String = { if (dataType.asInstanceOf[ArrayType].containsNull) { s""" - |if ($nullElementIndex < 0) { - | // result has no null element - | $nonNullArrayDataBuild - |} else { + |$body + |if ($nullElementIndex >= 0) { | // result has null element - | $arrayDataBuilder - | $javaTypeName[] $array = $builder.result(); - | for (int $i = 0, $pos = 0; $pos < $size; $pos++) { - | if ($pos == $nullElementIndex) { - | ${ev.value}.setNullAt($pos); - | } else { - | $javaTypeName $value = $array[$i++]; - | ${ev.value}.$setter; - | } - | } + | $value.setNullAt($nullElementIndex); |} """.stripMargin } else { - nonNullArrayDataBuild + body } } + + def buildResultArray( + builder: String, + value : String, + size : String, + nullElementIndex : String): String = withResultArrayNullCheck( + s""" + |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Cannot create array with " + $size + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData."); + |} + | + |if (!UnsafeArrayData.shouldUseGenericArrayData(${elementType.defaultSize}, $size)) { + | $value = UnsafeArrayData.fromPrimitiveArray($builder.result()); + |} else { + | $value = new ${classOf[GenericArrayData].getName}($builder.result()); + |} + """.stripMargin, value, nullElementIndex) } object ArraySetLike { @@ -4086,10 +4048,14 @@ object ArrayUnion { array(1, 3) """, since = "2.4.0") -case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetLike { - override def dataType: DataType = ArrayType(elementType, - left.dataType.asInstanceOf[ArrayType].containsNull && - right.dataType.asInstanceOf[ArrayType].containsNull) +case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { + override def dataType: DataType = { + dataTypeCheck + ArrayType(elementType, + left.dataType.asInstanceOf[ArrayType].containsNull && + right.dataType.asInstanceOf[ArrayType].containsNull) + } @transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = { if (elementTypeSupportEquals) { @@ -4182,100 +4148,115 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayData = classOf[ArrayData].getName val i = ctx.freshName("i") - val pos = ctx.freshName("pos") val value = ctx.freshName("value") - val hsValue = ctx.freshName("hsValue") val size = ctx.freshName("size") - if (elementTypeSupportEquals) { - val unsafeArray = ctx.freshName("unsafeArray") - val (postFix, openHashElementType, hsJavaTypeName, genHsValue, - getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) = - makeAccessors(ctx, ev, unsafeArray, size, value, pos, i) + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) nullSafeCodeGen(ctx, ev, (array1, array2) => { val foundNullElement = ctx.freshName("foundNullElement") val nullElementIndex = ctx.freshName("nullElementIndex") val builder = ctx.freshName("builder") - val array = ctx.freshName("array") val openHashSet = classOf[OpenHashSet[_]].getName - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" - val hs = ctx.freshName("hs") - val hsResult = ctx.freshName("hsResult") + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") + val hashSetResult = ctx.freshName("hashSetResult") val arrayBuilder = "scala.collection.mutable.ArrayBuilder" - val arrayBuilderClass = s"$arrayBuilder$$of$primitiveTypeName" - val arrayBuilderClassTag = if (primitiveTypeName != "Ref") { - s"scala.reflect.ClassTag$$.MODULE$$.$primitiveTypeName()" - } else { - s"scala.reflect.ClassTag$$.MODULE$$.AnyRef()" - } + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()" - def withArray2NullCheck(body: String) = + def withArray2NullCheck(body: String): String = if (right.dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |if ($array2.isNullAt($i)) { - | $foundNullElement = true; - |} else { - | $body - |} - """.stripMargin + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $foundNullElement = true; + |} else { + | $body + |} + """.stripMargin + } else { + // if array1's element is not nullable, we don't need to track the null element index. + s""" + |if (!$array2.isNullAt($i)) { + | $body + |} + """.stripMargin + } } else { body } - val array2Body = + + val writeArray2ToHashSet = withArray2NullCheck( s""" - |$javaTypeName $value = $array2.$getter; - |$hsJavaTypeName $hsValue = $genHsValue; - |$hs.add$postFix($hsValue); - """.stripMargin + |$jt $value = ${genGetValue(array2, i)}; + |$hashSet.add$hsPostFix($hsValueCast$value); + """.stripMargin) def withArray1NullAssignment(body: String) = if (left.dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |if ($array1.isNullAt($i)) { - | if ($foundNullElement) { - | $nullElementIndex = $size; - | $foundNullElement = false; - | $size++; - | } - |} else { - | $body - |} - """.stripMargin + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = false; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} + """.stripMargin + } else { + s""" + |if (!$array1.isNullAt($i)) { + | $body + |} + """.stripMargin + } } else { body } - val array1Body = + + val processArray1 = withArray1NullAssignment( s""" - |$javaTypeName $value = $array1.$getter; - |$hsJavaTypeName $hsValue = $genHsValue; - |if ($hs.contains($hsValue) && !$hsResult.contains($hsValue)) { + |$jt $value = ${genGetValue(array1, i)}; + |if ($hashSet.contains($hsValueCast$value) && + | !$hashSetResult.contains($hsValueCast$value)) { | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | break; | } - | $hsResult.add$postFix($hsValue); + | $hashSetResult.add$hsPostFix($hsValueCast$value); | $builder.$$plus$$eq($value); |} + """.stripMargin) + + // Only need to track null element index when result array's element is nullable. + val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $foundNullElement = false; + |int $nullElementIndex = -1; """.stripMargin + } else { + "" + } - val nonNullArrayData = nonNullArrayDataBuild(postFix != "", ev, builder, size) - val resultArrayData = buildResultArrayData( - nonNullArrayData, arrayDataBuilder, ev, builder, setter, - javaTypeName, array, value, nullElementIndex, size, i, pos) s""" - |$openHashSet $hs = new $openHashSet$postFix($classTag); - |$openHashSet $hsResult = new $openHashSet$postFix($classTag); - |boolean $foundNullElement = false; + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$openHashSet $hashSetResult = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables |for (int $i = 0; $i < $array2.numElements(); $i++) { - | ${withArray2NullCheck(array2Body)} + | $writeArray2ToHashSet |} |$arrayBuilderClass $builder = | ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag); - |int $nullElementIndex = -1; |int $size = 0; |for (int $i = 0; $i < $array1.numElements(); $i++) { - | ${withArray1NullAssignment(array1Body)} + | $processArray1 |} - |$resultArrayData + |${buildResultArray(builder, ev.value, size, nullElementIndex)} """.stripMargin }) } else { @@ -4404,31 +4385,10 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val i = ctx.freshName("i") val value = ctx.freshName("value") val size = ctx.freshName("size") - val canUseSpecializedHashSet = elementType match { - case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true - case _ => false - } if (canUseSpecializedHashSet) { val jt = CodeGenerator.javaType(elementType) val ptName = CodeGenerator.primitiveTypeName(jt) - def genGetValue(array: String): String = - CodeGenerator.getValue(array, elementType, i) - - val (hsPostFix, hsTypeName) = elementType match { - // we cast byte/short to int when writing to the hash set. - case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") - case LongType => ("$mcJ$sp", ptName) - case FloatType => ("$mcF$sp", ptName) - case DoubleType => ("$mcD$sp", ptName) - } - - // we cast byte/short to int when writing to the hash set. - val hsValueCast = elementType match { - case ByteType | ShortType => "(int) " - case _ => "" - } - nullSafeCodeGen(ctx, ev, (array1, array2) => { val notFoundNullElement = ctx.freshName("notFoundNullElement") val nullElementIndex = ctx.freshName("nullElementIndex") @@ -4436,7 +4396,6 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val openHashSet = classOf[OpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" val hashSet = ctx.freshName("hashSet") - val genericArrayData = classOf[GenericArrayData].getName val arrayBuilder = "scala.collection.mutable.ArrayBuilder" val arrayBuilderClass = s"$arrayBuilder$$of$ptName" val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()" @@ -4465,18 +4424,10 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val writeArray2ToHashSet = withArray2NullCheck( s""" - |$jt $value = ${genGetValue(array2)}; + |$jt $value = ${genGetValue(array2, i)}; |$hashSet.add$hsPostFix($hsValueCast$value); """.stripMargin) - // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will - // convert ArrayBuilder to ArrayData and setNull on the slot with null holder. - val nullValueHolder = elementType match { - case ByteType => "(byte) 0" - case ShortType => "(short) 0" - case _ => "0" - } - def withArray1NullAssignment(body: String) = if (left.dataType.asInstanceOf[ArrayType].containsNull) { s""" @@ -4497,7 +4448,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val processArray1 = withArray1NullAssignment( s""" - |$jt $value = ${genGetValue(array1)}; + |$jt $value = ${genGetValue(array1, i)}; |if (!$hashSet.contains($hsValueCast$value)) { | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | break; @@ -4507,35 +4458,6 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike |} """.stripMargin) - def withResultArrayNullCheck(body: String): String = { - if (dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |$body - |if ($nullElementIndex >= 0) { - | // result has null element - | ${ev.value}.setNullAt($nullElementIndex); - |} - """.stripMargin - } else { - body - } - } - - val buildResultArray = withResultArrayNullCheck( - s""" - |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Cannot create array with " + $size + - | " bytes of data due to exceeding the limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData."); - |} - | - |if (!UnsafeArrayData.shouldUseGenericArrayData(${elementType.defaultSize}, $size)) { - | ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result()); - |} else { - | ${ev.value} = new $genericArrayData($builder.result()); - |} - """.stripMargin) - // Only need to track null element index when array1's element is nullable. val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) { s""" @@ -4558,7 +4480,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike |for (int $i = 0; $i < $array1.numElements(); $i++) { | $processArray1 |} - |$buildResultArray + |${buildResultArray(builder, ev.value, size, nullElementIndex)} """.stripMargin }) } else { From ce755e2b049ca000d6da754654b792e181e6d904 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 6 Aug 2018 06:59:25 +0100 Subject: [PATCH 20/21] address review comments --- .../expressions/collectionOperations.scala | 144 +++++++++--------- .../spark/sql/DataFrameFunctionsSuite.scala | 12 +- 2 files changed, 80 insertions(+), 76 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ccdd8aaa96aa1..f0f478642b3e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4043,7 +4043,7 @@ object ArrayUnion { array2, without duplicates. """, examples = """ - Examples:Fun + Examples: > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); array(1, 3) """, @@ -4060,81 +4060,89 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL @transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = { if (elementTypeSupportEquals) { (array1, array2) => - val hs = new OpenHashSet[Any] - val hsResult = new OpenHashSet[Any] - var foundNullElement = false - var i = 0 - while (i < array2.numElements()) { - if (array2.isNullAt(i)) { - foundNullElement = true - } else { - val elem = array2.get(i, elementType) - hs.add(elem) - } - i += 1 - } - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - i = 0 - while (i < array1.numElements()) { - if (array1.isNullAt(i)) { - if (foundNullElement) { - arrayBuffer += null - foundNullElement = false + if (array1.numElements() != 0 && array2.numElements() != 0) { + val hs = new OpenHashSet[Any] + val hsResult = new OpenHashSet[Any] + var foundNullElement = false + var i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + foundNullElement = true + } else { + val elem = array2.get(i, elementType) + hs.add(elem) } - } else { - val elem = array1.get(i, elementType) - if (hs.contains(elem) && !hsResult.contains(elem)) { - arrayBuffer += elem - hsResult.add(elem) + i += 1 + } + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (foundNullElement) { + arrayBuffer += null + foundNullElement = false + } + } else { + val elem = array1.get(i, elementType) + if (hs.contains(elem) && !hsResult.contains(elem)) { + arrayBuffer += elem + hsResult.add(elem) + } } + i += 1 } - i += 1 + new GenericArrayData(arrayBuffer) + } else { + new GenericArrayData(Seq.empty) } - new GenericArrayData(arrayBuffer) } else { (array1, array2) => - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - var alreadySeenNull = false - var i = 0 - while (i < array1.numElements()) { - var found = false - val elem1 = array1.get(i, elementType) - if (array1.isNullAt(i)) { - if (!alreadySeenNull) { + if (array1.numElements() != 0 && array2.numElements() != 0) { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadySeenNull = false + var i = 0 + while (i < array1.numElements()) { + var found = false + val elem1 = array1.get(i, elementType) + if (array1.isNullAt(i)) { + if (!alreadySeenNull) { + var j = 0 + while (!found && j < array2.numElements()) { + found = array2.isNullAt(j) + j += 1 + } + // array2 is scanned only once for null element + alreadySeenNull = true + } + } else { var j = 0 while (!found && j < array2.numElements()) { - found = array2.isNullAt(j) - j += 1 - } - // array2 is scanned only once for null element - alreadySeenNull = true - } - } else { - var j = 0 - while (!found && j < array2.numElements()) { - if (!array2.isNullAt(j)) { - val elem2 = array2.get(j, elementType) - if (ordering.equiv(elem1, elem2)) { - // check whether elem1 is already stored in arrayBuffer - var foundArrayBuffer = false - var k = 0 - while (!foundArrayBuffer && k < arrayBuffer.size) { - val va = arrayBuffer(k) - foundArrayBuffer = (va != null) && ordering.equiv(va, elem1) - k += 1 + if (!array2.isNullAt(j)) { + val elem2 = array2.get(j, elementType) + if (ordering.equiv(elem1, elem2)) { + // check whether elem1 is already stored in arrayBuffer + var foundArrayBuffer = false + var k = 0 + while (!foundArrayBuffer && k < arrayBuffer.size) { + val va = arrayBuffer(k) + foundArrayBuffer = (va != null) && ordering.equiv(va, elem1) + k += 1 + } + found = !foundArrayBuffer } - found = !foundArrayBuffer } + j += 1 } - j += 1 } + if (found) { + arrayBuffer += elem1 + } + i += 1 } - if (found) { - arrayBuffer += elem1 - } - i += 1 + new GenericArrayData(arrayBuffer) + } else { + new GenericArrayData(Seq.empty) } - new GenericArrayData(arrayBuffer) } } @@ -4162,9 +4170,8 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" val hashSet = ctx.freshName("hashSet") val hashSetResult = ctx.freshName("hashSetResult") - val arrayBuilder = "scala.collection.mutable.ArrayBuilder" + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName val arrayBuilderClass = s"$arrayBuilder$$of$ptName" - val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()" def withArray2NullCheck(body: String): String = if (right.dataType.asInstanceOf[ArrayType].containsNull) { @@ -4250,8 +4257,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL |for (int $i = 0; $i < $array2.numElements(); $i++) { | $writeArray2ToHashSet |} - |$arrayBuilderClass $builder = - | ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag); + |$arrayBuilderClass $builder = new $arrayBuilderClass(); |int $size = 0; |for (int $i = 0; $i < $array1.numElements(); $i++) { | $processArray1 @@ -4396,9 +4402,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val openHashSet = classOf[OpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" val hashSet = ctx.freshName("hashSet") - val arrayBuilder = "scala.collection.mutable.ArrayBuilder" + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName val arrayBuilderClass = s"$arrayBuilder$$of$ptName" - val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()" def withArray2NullCheck(body: String): String = if (right.dataType.asInstanceOf[ArrayType].containsNull) { @@ -4474,8 +4479,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike |for (int $i = 0; $i < $array2.numElements(); $i++) { | $writeArray2ToHashSet |} - |$arrayBuilderClass $builder = - | ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag); + |$arrayBuilderClass $builder = new $arrayBuilderClass(); |int $size = 0; |for (int $i = 0; $i < $array1.numElements(); $i++) { | $processArray1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 011d2022035dc..40d5ee07cc60d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1679,26 +1679,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val df6 = Seq((null, null)).toDF("a", "b") intercept[AnalysisException] { df6.select(array_intersect($"a", $"b")) - } + }.getMessage.contains("data type mismatch") intercept[AnalysisException] { df6.selectExpr("array_intersect(a, b)") - } + }.getMessage.contains("data type mismatch") val df7 = Seq((Array(1), Array("a"))).toDF("a", "b") intercept[AnalysisException] { df7.select(array_intersect($"a", $"b")) - } + }.getMessage.contains("data type mismatch") intercept[AnalysisException] { df7.selectExpr("array_intersect(a, b)") - } + }.getMessage.contains("data type mismatch") val df8 = Seq((null, Array("a"))).toDF("a", "b") intercept[AnalysisException] { df8.select(array_intersect($"a", $"b")) - } + }.getMessage.contains("data type mismatch") intercept[AnalysisException] { df8.selectExpr("array_intersect(a, b)") - } + }.getMessage.contains("data type mismatch") } test("transform function - array for primitive type not containing null") { From 33781b640ed447d9a73a93b63e1834dd9360e72a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 6 Aug 2018 07:28:47 +0100 Subject: [PATCH 21/21] address review comment --- .../expressions/collectionOperations.scala | 4 ++-- .../spark/sql/DataFrameFunctionsSuite.scala | 24 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index f0f478642b3e4..e385c2d9782e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4093,7 +4093,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL } new GenericArrayData(arrayBuffer) } else { - new GenericArrayData(Seq.empty) + new GenericArrayData(Array.emptyObjectArray) } } else { (array1, array2) => @@ -4141,7 +4141,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL } new GenericArrayData(arrayBuffer) } else { - new GenericArrayData(Seq.empty) + new GenericArrayData(Array.emptyObjectArray) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 40d5ee07cc60d..2e6ef11a54d2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1677,28 +1677,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df5.selectExpr("array_intersect(a, b)"), ans5) val df6 = Seq((null, null)).toDF("a", "b") - intercept[AnalysisException] { + assert(intercept[AnalysisException] { df6.select(array_intersect($"a", $"b")) - }.getMessage.contains("data type mismatch") - intercept[AnalysisException] { + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { df6.selectExpr("array_intersect(a, b)") - }.getMessage.contains("data type mismatch") + }.getMessage.contains("data type mismatch")) val df7 = Seq((Array(1), Array("a"))).toDF("a", "b") - intercept[AnalysisException] { + assert(intercept[AnalysisException] { df7.select(array_intersect($"a", $"b")) - }.getMessage.contains("data type mismatch") - intercept[AnalysisException] { + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { df7.selectExpr("array_intersect(a, b)") - }.getMessage.contains("data type mismatch") + }.getMessage.contains("data type mismatch")) val df8 = Seq((null, Array("a"))).toDF("a", "b") - intercept[AnalysisException] { + assert(intercept[AnalysisException] { df8.select(array_intersect($"a", $"b")) - }.getMessage.contains("data type mismatch") - intercept[AnalysisException] { + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { df8.selectExpr("array_intersect(a, b)") - }.getMessage.contains("data type mismatch") + }.getMessage.contains("data type mismatch")) } test("transform function - array for primitive type not containing null") {