From 491631d17a7b11038df627f7e3cd3e62674f991b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 18 Apr 2018 19:03:55 +0100 Subject: [PATCH 01/29] initial commit --- python/pyspark/sql/functions.py | 19 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 241 +++++++++++++++++- .../CollectionExpressionsSuite.scala | 49 +++- .../org/apache/spark/sql/functions.scala | 11 + .../spark/sql/DataFrameFunctionsSuite.scala | 79 +++++- 6 files changed, 394 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0a88e482787ff..18ad0e0ca2e86 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2052,6 +2052,25 @@ def array_union(col1, col2): return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2))) +@ignore_unicode_prefix +@since(2.4) +def array_except(col1, col2): + """ + Collection function: returns an array of the elements in col1 but not in col2, + without duplicates. + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_except(df.c1, df.c2)).collect() + [Row(array_except(c1, c2)=[u'b']))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_except(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index adc4837276793..b8b311219ca8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -414,6 +414,7 @@ object FunctionRegistry { expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), + expression[ArrayExcept]("array_except"), expression[ArrayUnion]("array_union"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b1d91ffbe86e0..610e024a9f05e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -69,7 +69,6 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression } } - /** * Given an array or map, returns total number of elements in it. */ @@ -954,7 +953,6 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { override def prettyName: String = "map_from_entries" } - /** * Common base class for [[SortArray]] and [[ArraySort]]. 
*/ @@ -1096,6 +1094,204 @@ object ArraySortLike { } } +abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { + val kindUnion = 1 + val kindIntersect = 2 + val kindExcept = 3 + def typeId: Int + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { + val r = super.checkInputDataTypes() + if ((r == TypeCheckResult.TypeCheckSuccess) && + (left.dataType.asInstanceOf[ArrayType].elementType != + right.dataType.asInstanceOf[ArrayType].elementType)) { + TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") + } else { + r + } + } + + override def dataType: DataType = left.dataType + + private def elementType = dataType.asInstanceOf[ArrayType].elementType + private def cn1 = left.dataType.asInstanceOf[ArrayType].containsNull + private def cn2 = right.dataType.asInstanceOf[ArrayType].containsNull + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val ary1 = input1.asInstanceOf[ArrayData] + val ary2 = input2.asInstanceOf[ArrayData] + + if (!cn1 && !cn2) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + var hs: OpenHashSet[Int] = null + val hs1 = new OpenHashSet[Int] + var i = 0 + while (i < ary1.numElements()) { + hs1.add(ary1.getInt(i)) + i += 1 + } + if (typeId == kindUnion) { + i = 0 + while (i < ary2.numElements()) { + hs1.add(ary2.getInt(i)) + i += 1 + } + hs = hs1 + } else { + val c = typeId == kindIntersect + hs = new OpenHashSet[Int] + i = 0 + while (i < ary2.numElements()) { + val k = ary2.getInt(i) + if (hs1.contains(k) == c) { + hs.add(k) + } + i += 1 + } + } + UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) + case LongType => + // avoid boxing of primitive long array elements + var hs: OpenHashSet[Long] = null + val hs1 = new OpenHashSet[Long] + var i = 0 + while (i < ary1.numElements()) { + hs1.add(ary1.getLong(i)) + i += 1 + } + if (typeId == kindUnion) { + i = 0 + while (i < ary2.numElements()) { + hs1.add(ary2.getLong(i)) + i += 1 + } + hs = hs1 + } else { + val c = typeId == kindIntersect + hs = new OpenHashSet[Long] + i = 0 + while (i < ary2.numElements()) { + val k = ary2.getLong(i) + if (hs1.contains(k) == c) { + hs.add(k) + } + i += 1 + } + } + UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) + case _ => + var hs: OpenHashSet[Any] = null + val hs1 = new OpenHashSet[Any] + var i = 0 + while (i < ary1.numElements()) { + hs1.add(ary1.get(i, elementType)) + i += 1 + } + if (typeId == kindUnion) { + i = 0 + while (i < ary2.numElements()) { + hs1.add(ary2.get(i, elementType)) + i += 1 + } + hs = hs1 + } else { + val c = typeId == kindIntersect + hs = new OpenHashSet[Any] + i = 0 + while (i < ary2.numElements()) { + val k = ary2.get(i, elementType) + if (hs1.contains(k) == c) { + hs.add(k) + } + i += 1 + } + } + new GenericArrayData(hs.iterator.toArray) + } + } else { + if (typeId == kindUnion) { + ArraySetUtils.arrayUnion(ary1, ary2, elementType) + } else if (typeId == kindIntersect) { + ArraySetUtils.arrayIntersect(ary1, ary2, elementType) + } else { + ArraySetUtils.arrayExcept(ary2, ary1, elementType) + } + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val hs = ctx.freshName("hs") + val hs1 = ctx.freshName("hs1") + val i = ctx.freshName("i") + val ArraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetUtils" + val genericArrayData = classOf[GenericArrayData].getName + val unsafeArrayData = classOf[UnsafeArrayData].getName + 
val openHashSet = classOf[OpenHashSet[_]].getName + val et = s"org.apache.spark.sql.types.DataTypes.$elementType" + val (postFix, classTag, getter, arrayBuilder, castType) = if (!cn1 && !cn2) { + val ptName = CodeGenerator.primitiveTypeName(elementType) + elementType match { + case ByteType | ShortType | IntegerType => + (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", + s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType)) + case LongType => + (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", + s"$unsafeArrayData.fromPrimitiveArray", "long") + case _ => + ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $et)", + s"new $genericArrayData", "Object") + } + } else { + ("", "", "", "", "") + } + + nullSafeCodeGen(ctx, ev, (ary1, ary2) => { + if (classTag != "") { + if (typeId == kindUnion) { + s""" + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |for (int $i = 0; $i < $ary1.numElements(); $i++) { + | $hs.add$postFix($ary1.$getter); + |} + |for (int $i = 0; $i < $ary2.numElements(); $i++) { + | $hs.add$postFix($ary2.$getter); + |} + |${ev.value} = $arrayBuilder(($castType[]) $hs.iterator().toArray($classTag)); + """.stripMargin + } else { + val condPrefix = if (typeId == kindIntersect) "" else "!" + s""" + |$openHashSet $hs1 = new $openHashSet$postFix($classTag); + |for (int $i = 0; $i < $ary1.numElements(); $i++) { + | $hs1.add$postFix($ary1.$getter); + |} + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |for (int $i = 0; $i < $ary2.numElements(); $i++) { + | if ($condPrefix$hs1.contains$postFix($ary2.$getter)) { + | $hs.add$postFix($ary2.$getter); + | } + |} + |${ev.value} = $arrayBuilder(($castType[]) $hs.iterator().toArray($classTag)); + """.stripMargin + } + } else { + val setOp = if (typeId == kindUnion) { + "Union" + } else if (typeId == kindIntersect) { + "Intersect" + } else { + "Except" + } + s"${ev.value} = $ArraySetUtils$$.MODULE$$.array$setOp($ary2, $ary1, $et);" + } + }) + } +} + /** * Sorts the input array in ascending / descending order according to the natural ordering of * the array elements and returns it. 
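For reference, the fallback branch above hands off to ArraySetUtils.arrayUnion / arrayIntersect / arrayExcept, which this patch defines further down. A minimal plain-Scala sketch of those reference semantics (the values a and b are illustrative, not names from the patch):

    val a = Seq(1, 2, 4)
    val b = Seq(4, 2, 5)
    a.union(b).distinct      // kindUnion:     Seq(1, 2, 4, 5)
    a.intersect(b).distinct  // kindIntersect: Seq(2, 4)
    a.diff(b).distinct       // kindExcept:    Seq(1)

The OpenHashSet fast paths above are expected to agree with these results for non-null int/long elements.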
@@ -2372,7 +2568,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
         ByteArray.concat(inputs: _*)
       case StringType =>
         val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
-        UTF8String.concat(inputs : _*)
+        UTF8String.concat(inputs: _*)
       case ArrayType(elementType, _) =>
         val inputs = children.toStream.map(_.eval(input))
         if (inputs.contains(null)) {
@@ -3968,3 +4164,42 @@ object ArrayUnion {
     new GenericArrayData(arrayBuffer)
   }
 }
+
+object ArraySetUtils {
+  def arrayUnion(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = {
+    new GenericArrayData(array1.toArray[AnyRef](et).union(array2.toArray[AnyRef](et))
+      .distinct.asInstanceOf[Array[Any]])
+  }
+
+  def arrayIntersect(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = {
+    new GenericArrayData(array1.toArray[AnyRef](et).intersect(array2.toArray[AnyRef](et))
+      .distinct.asInstanceOf[Array[Any]])
+  }
+
+  def arrayExcept(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = {
+    new GenericArrayData(array1.toArray[AnyRef](et).diff(array2.toArray[AnyRef](et))
+      .distinct.asInstanceOf[Array[Any]])
+  }
+}
+
+/**
+ * Returns an array of the elements in x but not in y, without duplicates
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array1, array2) - Returns an array of the elements in array1 but not in array2,
+      without duplicates. The order of elements in the result is not determined.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
+       array(2)
+  """,
+  since = "2.4.0")
+case class ArrayExcept(array1: Expression, array2: Expression) extends ArraySetUtils {
+  override def typeId: Int = kindExcept
+  override def left: Expression = array2
+  override def right: Expression = array1
+
+  override def prettyName: String = "array_except"
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 5c5728548e646..b2ec8ea925c97 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1032,7 +1032,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     intercept[Exception] {
       checkEvaluation(ElementAt(a0, Literal(0)), null)
     }.getMessage.contains("SQL array indices start at 1")
-    intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) }
+    intercept[Exception] {
+      checkEvaluation(ElementAt(a0, Literal(1.1)), null)
+    }
     checkEvaluation(ElementAt(a0, Literal(4)), null)
     checkEvaluation(ElementAt(a0, Literal(-4)), null)
@@ -1503,4 +1505,49 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     assert(!shuffle.fastEquals(shuffle.freshCopy()))
     assert(!shuffle.fastEquals(Shuffle(ai0, seed2)))
   }
+
+  test("Array Except") {
+    val a00 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType, false))
+    val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false))
+    val a02 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType))
+    val a03 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType))
+    val a04 = Literal.create(Seq(-5, 4, null, 2, -1), ArrayType(IntegerType))
+    val a05 = Literal.create(Seq.empty[Int], ArrayType(IntegerType))
+
+    val a10 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType, false))
+    val a11 =
Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) + val a12 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType)) + val a13 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType)) + val a14 = Literal.create(Seq(-5L, 4L, null, 2L, -1L), ArrayType(LongType)) + val a15 = Literal.create(Seq.empty[Long], ArrayType(LongType)) + + val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType)) + val a21 = Literal.create(Seq("c", null, "a", "f"), ArrayType(StringType)) + val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType)) + val a23 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false)) + val a24 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, false)) + + val a30 = Literal.create(Seq(null, null), ArrayType(NullType)) + + checkEvaluation(ArrayExcept(a00, a01), UnsafeArrayData.fromPrimitiveArray(Array(1))) + checkEvaluation(ArrayExcept(a02, a01), Seq(1)) + checkEvaluation(ArrayExcept(a03, a02), Seq(null, 5)) + checkEvaluation(ArrayExcept(a03, a04), Seq(1, 5)) + checkEvaluation(ArrayExcept(a03, a05), Seq(1, 2, null, 4, 5)) + checkEvaluation(ArrayExcept(a05, a03), Seq.empty) + + checkEvaluation(ArrayExcept(a10, a11), UnsafeArrayData.fromPrimitiveArray(Array(1L))) + checkEvaluation(ArrayExcept(a12, a11), Seq(1L)) + checkEvaluation(ArrayExcept(a13, a12), Seq(null, 5L)) + checkEvaluation(ArrayExcept(a13, a14), Seq(1L, 5L)) + checkEvaluation(ArrayExcept(a13, a15), Seq(1L, 2L, null, 4L, 5L)) + checkEvaluation(ArrayExcept(a15, a13), Seq.empty) + + checkEvaluation(ArrayExcept(a20, a21), Seq("b")) + checkEvaluation(ArrayExcept(a21, a22), Seq("c", "f")) + checkEvaluation(ArrayExcept(a22, a23), Seq(null, "g")) + checkEvaluation(ArrayExcept(a23, a24), Seq("b")) + + checkEvaluation(ArrayExcept(a30, a30), Seq.empty) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bcd0c946ab996..760b2219cb888 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3229,6 +3229,17 @@ object functions { ArrayUnion(col1.expr, col2.expr) } + /** + * Returns an array of the elements in the first array but not in the second array, + * without duplicates. The order of elements in the result is not determined + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_except(col1: Column, col2: Column): Column = withExpr { + ArrayExcept(col1.expr, col2.expr) + } + /** * Creates a new row for each element in the given array or map column. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 299c96f74af22..2241cdbf6f2ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1173,8 +1173,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("concat function - arrays") { - val nseqi : Seq[Int] = null - val nseqs : Seq[String] = null + val nseqi: Seq[Int] = null + val nseqs: Seq[String] = null val df = Seq( (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) @@ -1204,11 +1204,42 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } +<<<<<<< HEAD // Test with local relation, the Project will be evaluated without codegen simpleTest() // Test with cached relation, the Project will be evaluated with codegen df.cache() simpleTest() +======= + checkAnswer( + df.select(concat($"i1", $"s1")), + Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) + ) + checkAnswer( + df.select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.selectExpr("concat(array(1, null), i2, i3)"), + Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) + ) + checkAnswer( + df.select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.selectExpr("concat(s1, s2, s3)"), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.filter(dummyFilter($"s1")) select (concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) +>>>>>>> initial commit // Null test cases def nullTest(): Unit = { @@ -1513,6 +1544,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } +<<<<<<< HEAD // Shuffle expressions should produce same results at retries in the same DataFrame. 
private def checkShuffleResult(df: DataFrame): Unit = { checkAnswer(df, df.collect()) @@ -1578,6 +1610,49 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { testNonPrimitiveType() } + test("array_except functions") { + val df1 = Seq((Array(1, 2, 4), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(1)) + checkAnswer(df1.select(array_except($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_except(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array[Integer](-5, 4, null, 2, -1))) + .toDF("a", "b") + val ans2 = Row(Seq(1, 5)) + checkAnswer(df2.select(array_except($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_except(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 4L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(1L)) + checkAnswer(df3.select(array_except($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_except(a, b)"), ans3) + + val df4 = Seq( + (Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array[java.lang.Long](-5L, 4L, null, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(1L, 5L)) + checkAnswer(df4.select(array_except($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_except(a, b)"), ans4) + + val df5 = Seq((Array("c", null, "a", "f"), Array("b", null, "a", "g"))).toDF("a", "b") + val ans5 = Row(Seq("c", "f")) + checkAnswer(df5.select(array_except($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_except(a, b)"), ans5) + + val df6 = Seq((null, null)).toDF("a", "b") + val ans6 = Row(null) + checkAnswer(df6.select(array_except($"a", $"b")), ans6) + checkAnswer(df6.selectExpr("array_except(a, b)"), ans6) + + val df0 = Seq((Array(1), Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df0.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df0.selectExpr("array_except(a, b)") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From e7469487b8cfccd1e3ef2617ac5de498b4b358e5 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 03:12:25 +0100 Subject: [PATCH 02/29] fix pyspark test failure --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 18ad0e0ca2e86..778fa787ed8ca 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2065,7 +2065,7 @@ def array_except(col1, col2): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_except(df.c1, df.c2)).collect() - [Row(array_except(c1, c2)=[u'b']))] + [Row(array_except(c1, c2)=[u'b'])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_except(_to_java_column(col1), _to_java_column(col2))) From a0bc35664fe70d53f1d3ca6a3fb0ebbca466cf01 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 19:57:42 +0100 Subject: [PATCH 03/29] address review comments in #21061 --- .../expressions/collectionOperations.scala | 427 ++++++++++-------- .../CollectionExpressionsSuite.scala | 2 + .../spark/sql/DataFrameFunctionsSuite.scala | 20 +- 3 files changed, 247 insertions(+), 202 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 610e024a9f05e..dd36376b4d595 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1094,204 +1094,6 @@ object ArraySortLike { } } -abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { - val kindUnion = 1 - val kindIntersect = 2 - val kindExcept = 3 - def typeId: Int - - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) - - override def checkInputDataTypes(): TypeCheckResult = { - val r = super.checkInputDataTypes() - if ((r == TypeCheckResult.TypeCheckSuccess) && - (left.dataType.asInstanceOf[ArrayType].elementType != - right.dataType.asInstanceOf[ArrayType].elementType)) { - TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") - } else { - r - } - } - - override def dataType: DataType = left.dataType - - private def elementType = dataType.asInstanceOf[ArrayType].elementType - private def cn1 = left.dataType.asInstanceOf[ArrayType].containsNull - private def cn2 = right.dataType.asInstanceOf[ArrayType].containsNull - - override def nullSafeEval(input1: Any, input2: Any): Any = { - val ary1 = input1.asInstanceOf[ArrayData] - val ary2 = input2.asInstanceOf[ArrayData] - - if (!cn1 && !cn2) { - elementType match { - case IntegerType => - // avoid boxing of primitive int array elements - var hs: OpenHashSet[Int] = null - val hs1 = new OpenHashSet[Int] - var i = 0 - while (i < ary1.numElements()) { - hs1.add(ary1.getInt(i)) - i += 1 - } - if (typeId == kindUnion) { - i = 0 - while (i < ary2.numElements()) { - hs1.add(ary2.getInt(i)) - i += 1 - } - hs = hs1 - } else { - val c = typeId == kindIntersect - hs = new OpenHashSet[Int] - i = 0 - while (i < ary2.numElements()) { - val k = ary2.getInt(i) - if (hs1.contains(k) == c) { - hs.add(k) - } - i += 1 - } - } - UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) - case LongType => - // avoid boxing of primitive long array elements - var hs: OpenHashSet[Long] = null - val hs1 = new OpenHashSet[Long] - var i = 0 - while (i < ary1.numElements()) { - hs1.add(ary1.getLong(i)) - i += 1 - } - if (typeId == kindUnion) { - i = 0 - while (i < ary2.numElements()) { - hs1.add(ary2.getLong(i)) - i += 1 - } - hs = hs1 - } else { - val c = typeId == kindIntersect - hs = new OpenHashSet[Long] - i = 0 - while (i < ary2.numElements()) { - val k = ary2.getLong(i) - if (hs1.contains(k) == c) { - hs.add(k) - } - i += 1 - } - } - UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray) - case _ => - var hs: OpenHashSet[Any] = null - val hs1 = new OpenHashSet[Any] - var i = 0 - while (i < ary1.numElements()) { - hs1.add(ary1.get(i, elementType)) - i += 1 - } - if (typeId == kindUnion) { - i = 0 - while (i < ary2.numElements()) { - hs1.add(ary2.get(i, elementType)) - i += 1 - } - hs = hs1 - } else { - val c = typeId == kindIntersect - hs = new OpenHashSet[Any] - i = 0 - while (i < ary2.numElements()) { - val k = ary2.get(i, elementType) - if (hs1.contains(k) == c) { - hs.add(k) - } - i += 1 - } - } - new GenericArrayData(hs.iterator.toArray) - } - } else { - if (typeId == kindUnion) { - ArraySetUtils.arrayUnion(ary1, ary2, elementType) - } else if (typeId == kindIntersect) { - ArraySetUtils.arrayIntersect(ary1, ary2, elementType) - } else { - ArraySetUtils.arrayExcept(ary2, ary1, elementType) - } - } - } - - override def 
doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val hs = ctx.freshName("hs") - val hs1 = ctx.freshName("hs1") - val i = ctx.freshName("i") - val ArraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetUtils" - val genericArrayData = classOf[GenericArrayData].getName - val unsafeArrayData = classOf[UnsafeArrayData].getName - val openHashSet = classOf[OpenHashSet[_]].getName - val et = s"org.apache.spark.sql.types.DataTypes.$elementType" - val (postFix, classTag, getter, arrayBuilder, castType) = if (!cn1 && !cn2) { - val ptName = CodeGenerator.primitiveTypeName(elementType) - elementType match { - case ByteType | ShortType | IntegerType => - (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", - s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType)) - case LongType => - (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", - s"$unsafeArrayData.fromPrimitiveArray", "long") - case _ => - ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $et)", - s"new $genericArrayData", "Object") - } - } else { - ("", "", "", "", "") - } - - nullSafeCodeGen(ctx, ev, (ary1, ary2) => { - if (classTag != "") { - if (typeId == kindUnion) { - s""" - |$openHashSet $hs = new $openHashSet$postFix($classTag); - |for (int $i = 0; $i < $ary1.numElements(); $i++) { - | $hs.add$postFix($ary1.$getter); - |} - |for (int $i = 0; $i < $ary2.numElements(); $i++) { - | $hs.add$postFix($ary2.$getter); - |} - |${ev.value} = $arrayBuilder(($castType[]) $hs.iterator().toArray($classTag)); - """.stripMargin - } else { - val condPrefix = if (typeId == kindIntersect) "" else "!" - s""" - |$openHashSet $hs1 = new $openHashSet$postFix($classTag); - |for (int $i = 0; $i < $ary1.numElements(); $i++) { - | $hs1.add$postFix($ary1.$getter); - |} - |$openHashSet $hs = new $openHashSet$postFix($classTag); - |for (int $i = 0; $i < $ary2.numElements(); $i++) { - | if ($condPrefix$hs1.contains$postFix($ary2.$getter)) { - | $hs.add$postFix($ary2.$getter); - | } - |} - |${ev.value} = $arrayBuilder(($castType[]) $hs.iterator().toArray($classTag)); - """.stripMargin - } - } else { - val setOp = if (typeId == kindUnion) { - "Union" - } else if (typeId == kindIntersect) { - "Intersect" - } else { - "Except" - } - s"${ev.value} = $ArraySetUtils$$.MODULE$$.array$setOp($ary2, $ary1, $et);" - } - }) - } -} - /** * Sorts the input array in ascending / descending order according to the natural ordering of * the array elements and returns it. 
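The following hunk rebuilds ArraySetUtils around explicit toUnsafeIntArray/toUnsafeLongArray helpers that drain an OpenHashSet through its position API rather than hs.iterator.toArray, avoiding a boxing iterator. A minimal self-contained sketch of that iteration pattern (the input values and the out buffer are illustrative only):

    import org.apache.spark.util.collection.OpenHashSet

    val hs = new OpenHashSet[Int]
    Seq(1, 2, 4, 2).foreach(hs.add)          // duplicates collapse in the set
    val out = new Array[Int](hs.size)
    var pos = hs.nextPos(0)                  // first occupied slot
    var i = 0
    while (pos != OpenHashSet.INVALID_POS) {
      out(i) = hs.getValue(pos)              // read the slot without boxing
      pos = hs.nextPos(pos + 1)              // next occupied slot, or INVALID_POS
      i += 1
    }
    // out now holds 1, 2 and 4, in hash-slot order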
@@ -4166,6 +3968,34 @@ object ArrayUnion { } object ArraySetUtils { + val kindUnion = 1 + val kindIntersect = 2 + val kindExcept = 3 + + def toUnsafeIntArray(hs: OpenHashSet[Int]): UnsafeArrayData = { + val array = new Array[Int](hs.size) + var pos = hs.nextPos(0) + var i = 0 + while (pos != OpenHashSet.INVALID_POS) { + array(i) = hs.getValue(pos) + pos = hs.nextPos(pos + 1) + i += 1 + } + UnsafeArrayData.fromPrimitiveArray(array) + } + + def toUnsafeLongArray(hs: OpenHashSet[Long]): UnsafeArrayData = { + val array = new Array[Long](hs.size) + var pos = hs.nextPos(0) + var i = 0 + while (pos != OpenHashSet.INVALID_POS) { + array(i) = hs.getValue(pos) + pos = hs.nextPos(pos + 1) + i += 1 + } + UnsafeArrayData.fromPrimitiveArray(array) + } + def arrayUnion(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { new GenericArrayData(array1.toArray[AnyRef](et).union(array2.toArray[AnyRef](et)) .distinct.asInstanceOf[Array[Any]]) @@ -4182,6 +4012,143 @@ object ArraySetUtils { } } +abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { + def typeId: Int + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { + val r = super.checkInputDataTypes() + if ((r == TypeCheckResult.TypeCheckSuccess) && + (left.dataType.asInstanceOf[ArrayType].elementType != + right.dataType.asInstanceOf[ArrayType].elementType)) { + TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") + } else { + r + } + } + + override def dataType: DataType = left.dataType + + private def elementType = dataType.asInstanceOf[ArrayType].elementType + private def cn = left.dataType.asInstanceOf[ArrayType].containsNull || + right.dataType.asInstanceOf[ArrayType].containsNull + + def intEval(ary: ArrayData, hs1: OpenHashSet[Int]): OpenHashSet[Int] + def longEval(ary: ArrayData, hs1: OpenHashSet[Long]): OpenHashSet[Long] + def genericEval(ary: ArrayData, hs1: OpenHashSet[Any], et: DataType): OpenHashSet[Any] + def codeGen(ctx: CodegenContext, hs1: String, hs: String, len: String, getter: String, i: String, + postFix: String, newOpenHashSet: String): String + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val ary1 = input1.asInstanceOf[ArrayData] + val ary2 = input2.asInstanceOf[ArrayData] + + if (!cn) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + val hs1 = new OpenHashSet[Int] + var i = 0 + while (i < ary1.numElements()) { + hs1.add(ary1.getInt(i)) + i += 1 + } + ArraySetUtils.toUnsafeIntArray(intEval(ary2, hs1)) + case LongType => + // avoid boxing of primitive long array elements + val hs1 = new OpenHashSet[Long] + var i = 0 + while (i < ary1.numElements()) { + hs1.add(ary1.getLong(i)) + i += 1 + } + ArraySetUtils.toUnsafeLongArray(longEval(ary2, hs1)) + case _ => + var hs: OpenHashSet[Any] = null + val hs1 = new OpenHashSet[Any] + var i = 0 + while (i < ary1.numElements()) { + hs1.add(ary1.get(i, elementType)) + i += 1 + } + new GenericArrayData(genericEval(ary2, hs1, elementType).iterator.toArray) + } + } else { + if (typeId == ArraySetUtils.kindUnion) { + ArraySetUtils.arrayUnion(ary1, ary2, elementType) + } else if (typeId == ArraySetUtils.kindIntersect) { + ArraySetUtils.arrayIntersect(ary1, ary2, elementType) + } else { + ArraySetUtils.arrayExcept(ary2, ary1, elementType) + } + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val i = ctx.freshName("i") + val arraySetUtils = 
"org.apache.spark.sql.catalyst.expressions.ArraySetUtils" + val genericArrayData = classOf[GenericArrayData].getName + val unsafeArrayData = classOf[UnsafeArrayData].getName + val openHashSet = classOf[OpenHashSet[_]].getName + val et = s"org.apache.spark.sql.types.DataTypes.$elementType" + val (postFix, classTag, getter, arrayBuilder, javaTypeName) = if (!cn) { + val ptName = CodeGenerator.primitiveTypeName(elementType) + elementType match { + case ByteType | ShortType | IntegerType => + (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", + s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType)) + case LongType => + (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", + s"$unsafeArrayData.fromPrimitiveArray", "long") + case _ => + ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $et)", + s"new $genericArrayData", "Object") + } + } else { + ("", "", "", "", "") + } + + val hs = ctx.freshName("hs") + val hs1 = ctx.freshName("hs1") + val invalidPos = ctx.freshName("invalidPos") + val pos = ctx.freshName("pos") + val ary = ctx.freshName("ary") + nullSafeCodeGen(ctx, ev, (ary1, ary2) => { + if (classTag != "") { + val secondLoop = codeGen(ctx, hs1, hs, s"$ary2.numElements()", s"$ary2.$getter", i, + postFix, s"new $openHashSet$postFix($classTag)") + s""" + |$openHashSet $hs1 = new $openHashSet$postFix($classTag); + |for (int $i = 0; $i < $ary1.numElements(); $i++) { + | $hs1.add$postFix($ary1.$getter); + |} + |$secondLoop + |$javaTypeName[] $ary = new $javaTypeName[$hs.size()]; + |int $invalidPos = $openHashSet.INVALID_POS(); + |int $pos = $hs.nextPos(0); + |int $i = 0; + |while ($pos != $invalidPos) { + | $ary[$i] = ($javaTypeName) $hs.getValue$postFix($pos); + | $pos = $hs.nextPos($pos + 1); + | $i++; + |} + |${ev.value} = $arrayBuilder($ary); + """.stripMargin + } else { + val setOp = if (typeId == ArraySetUtils.kindUnion) { + "Union" + } else if (typeId == ArraySetUtils.kindIntersect) { + "Intersect" + } else { + "Except" + } + s"${ev.value} = $arraySetUtils$$.MODULE$$.array$setOp($ary2, $ary1, $et);" + } + }) + } +} + /** * Returns an array of the elements in the union of x and y, without duplicates */ @@ -4197,9 +4164,71 @@ object ArraySetUtils { """, since = "2.4.0") case class ArrayExcept(array1: Expression, array2: Expression) extends ArraySetUtils { - override def typeId: Int = kindExcept + override def typeId: Int = ArraySetUtils.kindExcept override def left: Expression = array2 override def right: Expression = array1 + override def intEval(ary: ArrayData, hs1: OpenHashSet[Int]): OpenHashSet[Int] = { + val hs = new OpenHashSet[Int] + var i = 0 + while (i < ary.numElements()) { + val k = ary.getInt(i) + if (!hs1.contains(k)) { + hs.add(k) + } + i += 1 + } + hs + } + + override def longEval(ary: ArrayData, hs1: OpenHashSet[Long]): OpenHashSet[Long] = { + val hs = new OpenHashSet[Long] + var i = 0 + while (i < ary.numElements()) { + val k = ary.getLong(i) + if (!hs1.contains(k)) { + hs.add(k) + } + i += 1 + } + hs + } + + override def genericEval( + ary: ArrayData, + hs1: OpenHashSet[Any], + et: DataType): OpenHashSet[Any] = { + val hs = new OpenHashSet[Any] + var i = 0 + while (i < ary.numElements()) { + val k = ary.get(i, et) + if (!hs1.contains(k)) { + hs.add(k) + } + i += 1 + } + hs + } + + override def codeGen( + ctx: CodegenContext, + hs1: String, + hs: String, + len: String, + getter: String, + i: String, + postFix: String, + newOpenHashSet: String): String = { + val openHashSet = 
classOf[OpenHashSet[_]].getName + s""" + |$openHashSet $hs = $newOpenHashSet; + |for (int $i = 0; $i < $len; $i++) { + | if (!$hs1.contains$postFix($getter)) { + | $hs.add$postFix($getter); + | } + |} + """.stripMargin + } + override def prettyName: String = "array_except" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index b2ec8ea925c97..3b8f660c33e5a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1548,6 +1548,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayExcept(a22, a23), Seq(null, "g")) checkEvaluation(ArrayExcept(a23, a24), Seq("b")) + //checkEvaluation(ArrayExcept(a20, a30), Seq("b", "a", "c")) + //checkEvaluation(ArrayExcept(a30, a20), Seq(null)) checkEvaluation(ArrayExcept(a30, a30), Seq.empty) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 2241cdbf6f2ac..1979ccd4b77a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1644,12 +1644,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df6.select(array_except($"a", $"b")), ans6) checkAnswer(df6.selectExpr("array_except(a, b)"), ans6) - val df0 = Seq((Array(1), Array("a"))).toDF("a", "b") + val df7 = Seq((Array(1), Array("a"))).toDF("a", "b") intercept[AnalysisException] { - df0.select(array_except($"a", $"b")) + df7.select(array_except($"a", $"b")) } intercept[AnalysisException] { - df0.selectExpr("array_except(a, b)") + df7.selectExpr("array_except(a, b)") + } + val df8 = Seq((Array("a"), null)).toDF("a", "b") + intercept[AnalysisException] { + df8.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df8.selectExpr("array_except(a, b)") + } + val df9 = Seq((null, Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df9.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df9.selectExpr("array_except(a, b)") } } From 03344079a236dab812718ceda568127963438454 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 20:26:46 +0100 Subject: [PATCH 04/29] fix style error --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 3b8f660c33e5a..b2ec8ea925c97 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1548,8 +1548,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayExcept(a22, a23), Seq(null, "g")) checkEvaluation(ArrayExcept(a23, a24), Seq("b")) - //checkEvaluation(ArrayExcept(a20, a30), Seq("b", "a", "c")) - //checkEvaluation(ArrayExcept(a30, a20), Seq(null)) 
checkEvaluation(ArrayExcept(a30, a30), Seq.empty) } } From 08c367093392610c3c49dad9071472a454b086b1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 20 Apr 2018 04:43:32 +0100 Subject: [PATCH 05/29] fix pyspark test failure --- .../expressions/collectionOperations.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index dd36376b4d595..6ed02574ed2bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4014,25 +4014,27 @@ object ArraySetUtils { abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { def typeId: Int + def array1: Expression + def array2: Expression override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) override def checkInputDataTypes(): TypeCheckResult = { val r = super.checkInputDataTypes() if ((r == TypeCheckResult.TypeCheckSuccess) && - (left.dataType.asInstanceOf[ArrayType].elementType != - right.dataType.asInstanceOf[ArrayType].elementType)) { + (array1.dataType.asInstanceOf[ArrayType].elementType != + array2.dataType.asInstanceOf[ArrayType].elementType)) { TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") } else { r } } - override def dataType: DataType = left.dataType + override def dataType: DataType = array1.dataType private def elementType = dataType.asInstanceOf[ArrayType].elementType - private def cn = left.dataType.asInstanceOf[ArrayType].containsNull || - right.dataType.asInstanceOf[ArrayType].containsNull + private def cn = array1.dataType.asInstanceOf[ArrayType].containsNull || + array2.dataType.asInstanceOf[ArrayType].containsNull def intEval(ary: ArrayData, hs1: OpenHashSet[Int]): OpenHashSet[Int] def longEval(ary: ArrayData, hs1: OpenHashSet[Long]): OpenHashSet[Long] @@ -4163,10 +4165,10 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { array(2) """, since = "2.4.0") -case class ArrayExcept(array1: Expression, array2: Expression) extends ArraySetUtils { +case class ArrayExcept(left: Expression, right: Expression) extends ArraySetUtils { override def typeId: Int = ArraySetUtils.kindExcept - override def left: Expression = array2 - override def right: Expression = array1 + override def array1: Expression = right + override def array2: Expression = left override def intEval(ary: ArrayData, hs1: OpenHashSet[Int]): OpenHashSet[Int] = { val hs = new OpenHashSet[Int] From d989fd75e0e8f8807bfd242bca114d3d7593b99b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 20 Apr 2018 13:51:09 +0100 Subject: [PATCH 06/29] fix failures --- .../expressions/collectionOperations.scala | 58 +++++++++---------- .../CollectionExpressionsSuite.scala | 3 + 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6ed02574ed2bb..c14316cc9ab24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4014,32 
+4014,30 @@ object ArraySetUtils { abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { def typeId: Int - def array1: Expression - def array2: Expression override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) override def checkInputDataTypes(): TypeCheckResult = { val r = super.checkInputDataTypes() if ((r == TypeCheckResult.TypeCheckSuccess) && - (array1.dataType.asInstanceOf[ArrayType].elementType != - array2.dataType.asInstanceOf[ArrayType].elementType)) { + (left.dataType.asInstanceOf[ArrayType].elementType != + right.dataType.asInstanceOf[ArrayType].elementType)) { TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") } else { r } } - override def dataType: DataType = array1.dataType + override def dataType: DataType = left.dataType private def elementType = dataType.asInstanceOf[ArrayType].elementType - private def cn = array1.dataType.asInstanceOf[ArrayType].containsNull || - array2.dataType.asInstanceOf[ArrayType].containsNull + private def cn = left.dataType.asInstanceOf[ArrayType].containsNull || + right.dataType.asInstanceOf[ArrayType].containsNull - def intEval(ary: ArrayData, hs1: OpenHashSet[Int]): OpenHashSet[Int] - def longEval(ary: ArrayData, hs1: OpenHashSet[Long]): OpenHashSet[Long] - def genericEval(ary: ArrayData, hs1: OpenHashSet[Any], et: DataType): OpenHashSet[Any] - def codeGen(ctx: CodegenContext, hs1: String, hs: String, len: String, getter: String, i: String, + def intEval(ary: ArrayData, hs2: OpenHashSet[Int]): OpenHashSet[Int] + def longEval(ary: ArrayData, hs2: OpenHashSet[Long]): OpenHashSet[Long] + def genericEval(ary: ArrayData, hs2: OpenHashSet[Any], et: DataType): OpenHashSet[Any] + def codeGen(ctx: CodegenContext, hs2: String, hs: String, len: String, getter: String, i: String, postFix: String, newOpenHashSet: String): String override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -4050,31 +4048,31 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { elementType match { case IntegerType => // avoid boxing of primitive int array elements - val hs1 = new OpenHashSet[Int] + val hs2 = new OpenHashSet[Int] var i = 0 - while (i < ary1.numElements()) { - hs1.add(ary1.getInt(i)) + while (i < ary2.numElements()) { + hs2.add(ary2.getInt(i)) i += 1 } - ArraySetUtils.toUnsafeIntArray(intEval(ary2, hs1)) + ArraySetUtils.toUnsafeIntArray(intEval(ary1, hs2)) case LongType => // avoid boxing of primitive long array elements - val hs1 = new OpenHashSet[Long] + val hs2 = new OpenHashSet[Long] var i = 0 - while (i < ary1.numElements()) { - hs1.add(ary1.getLong(i)) + while (i < ary2.numElements()) { + hs2.add(ary2.getLong(i)) i += 1 } - ArraySetUtils.toUnsafeLongArray(longEval(ary2, hs1)) + ArraySetUtils.toUnsafeLongArray(longEval(ary1, hs2)) case _ => var hs: OpenHashSet[Any] = null val hs1 = new OpenHashSet[Any] var i = 0 - while (i < ary1.numElements()) { - hs1.add(ary1.get(i, elementType)) + while (i < ary2.numElements()) { + hs1.add(ary2.get(i, elementType)) i += 1 } - new GenericArrayData(genericEval(ary2, hs1, elementType).iterator.toArray) + new GenericArrayData(genericEval(ary1, hs1, elementType).iterator.toArray) } } else { if (typeId == ArraySetUtils.kindUnion) { @@ -4082,7 +4080,7 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { } else if (typeId == ArraySetUtils.kindIntersect) { ArraySetUtils.arrayIntersect(ary1, ary2, elementType) } else { - ArraySetUtils.arrayExcept(ary2, ary1, elementType) + 
ArraySetUtils.arrayExcept(ary1, ary2, elementType) } } } @@ -4112,18 +4110,18 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { } val hs = ctx.freshName("hs") - val hs1 = ctx.freshName("hs1") + val hs2 = ctx.freshName("hs2") val invalidPos = ctx.freshName("invalidPos") val pos = ctx.freshName("pos") val ary = ctx.freshName("ary") nullSafeCodeGen(ctx, ev, (ary1, ary2) => { if (classTag != "") { - val secondLoop = codeGen(ctx, hs1, hs, s"$ary2.numElements()", s"$ary2.$getter", i, + val secondLoop = codeGen(ctx, hs2, hs, s"$ary1.numElements()", s"$ary1.$getter", i, postFix, s"new $openHashSet$postFix($classTag)") s""" - |$openHashSet $hs1 = new $openHashSet$postFix($classTag); - |for (int $i = 0; $i < $ary1.numElements(); $i++) { - | $hs1.add$postFix($ary1.$getter); + |$openHashSet $hs2 = new $openHashSet$postFix($classTag); + |for (int $i = 0; $i < $ary2.numElements(); $i++) { + | $hs2.add$postFix($ary2.$getter); |} |$secondLoop |$javaTypeName[] $ary = new $javaTypeName[$hs.size()]; @@ -4145,7 +4143,7 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { } else { "Except" } - s"${ev.value} = $arraySetUtils$$.MODULE$$.array$setOp($ary2, $ary1, $et);" + s"${ev.value} = $arraySetUtils$$.MODULE$$.array$setOp($ary1, $ary2, $et);" } }) } @@ -4167,8 +4165,6 @@ abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { since = "2.4.0") case class ArrayExcept(left: Expression, right: Expression) extends ArraySetUtils { override def typeId: Int = ArraySetUtils.kindExcept - override def array1: Expression = right - override def array2: Expression = left override def intEval(ary: ArrayData, hs1: OpenHashSet[Int]): OpenHashSet[Int] = { val hs = new OpenHashSet[Int] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index b2ec8ea925c97..5b9b8c2fc0d21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1528,6 +1528,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a24 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, false)) val a30 = Literal.create(Seq(null, null), ArrayType(NullType)) + val a31 = Literal.create(null, ArrayType(StringType)) checkEvaluation(ArrayExcept(a00, a01), UnsafeArrayData.fromPrimitiveArray(Array(1))) checkEvaluation(ArrayExcept(a02, a01), Seq(1)) @@ -1549,5 +1550,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayExcept(a23, a24), Seq("b")) checkEvaluation(ArrayExcept(a30, a30), Seq.empty) + checkEvaluation(ArrayExcept(a20, a31), null) + checkEvaluation(ArrayExcept(a31, a20), null) } } From 71374a838aa6db4f362ec3b76cf0729c84bc218f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 14 Jul 2018 07:09:52 +0100 Subject: [PATCH 07/29] update --- .../main/scala/org/apache/spark/package.scala | 6 +- .../expressions/collectionOperations.scala | 491 ++++++++++-------- .../CollectionExpressionsSuite.scala | 131 +++-- .../spark/sql/DataFrameFunctionsSuite.scala | 41 +- 4 files changed, 379 insertions(+), 290 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 
8058a4d5dbdea..f46808670d445 100644
--- a/core/src/main/scala/org/apache/spark/package.scala
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -57,7 +57,7 @@ package object spark {
     val resourceStream = Thread.currentThread().getContextClassLoader.
       getResourceAsStream("spark-version-info.properties")
     if (resourceStream == null) {
-      throw new SparkException("Could not find spark-version-info.properties")
+      // throw new SparkException("Could not find spark-version-info.properties")
     }

     try {
@@ -74,7 +74,9 @@ package object spark {
       )
     } catch {
       case e: Exception =>
-        throw new SparkException("Error loading properties from spark-version-info.properties", e)
+        val unknownProp = ""
+        (unknownProp, unknownProp, unknownProp, unknownProp, unknownProp, unknownProp)
+        // throw new SparkException("Error loading properties from spark-version-info.properties", e)
     } finally {
       if (resourceStream != null) {
         try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c14316cc9ab24..324d88f41ed6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3649,7 +3649,7 @@ case class ArrayDistinct(child: Expression)
 }

 /**
- * Will become common base class for [[ArrayUnion]], ArrayIntersect, and ArrayExcept.
+ * Will become common base class for [[ArrayUnion]], ArrayIntersect, and [[ArrayExcept]].
  */
 abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
   override def dataType: DataType = {
@@ -3967,266 +3967,327 @@ object ArrayUnion {
   }
 }

-object ArraySetUtils {
-  val kindUnion = 1
-  val kindIntersect = 2
-  val kindExcept = 3
+/**
+ * Returns an array of the elements in x but not in y, without duplicates
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array1, array2) - Returns an array of the elements in array1 but not in array2,
+      without duplicates.
+ """, + examples = """ + Examples:Fun + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(2) + """, + since = "2.4.0") +case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike { + var hsInt: OpenHashSet[Int] = _ + var hsLong: OpenHashSet[Long] = _ - def toUnsafeIntArray(hs: OpenHashSet[Int]): UnsafeArrayData = { - val array = new Array[Int](hs.size) - var pos = hs.nextPos(0) - var i = 0 - while (pos != OpenHashSet.INVALID_POS) { - array(i) = hs.getValue(pos) - pos = hs.nextPos(pos + 1) - i += 1 + def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getInt(idx) + if (!hsInt.contains(elem)) { + if (resultArray != null) { + resultArray.setInt(pos, elem) + } + hsInt.add(elem) + true + } else { + false } - UnsafeArrayData.fromPrimitiveArray(array) } - def toUnsafeLongArray(hs: OpenHashSet[Long]): UnsafeArrayData = { - val array = new Array[Long](hs.size) - var pos = hs.nextPos(0) - var i = 0 - while (pos != OpenHashSet.INVALID_POS) { - array(i) = hs.getValue(pos) - pos = hs.nextPos(pos + 1) - i += 1 + def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getLong(idx) + if (!hsLong.contains(elem)) { + if (resultArray != null) { + resultArray.setLong(pos, elem) + } + hsLong.add(elem) + true + } else { + false } - UnsafeArrayData.fromPrimitiveArray(array) - } - - def arrayUnion(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { - new GenericArrayData(array1.toArray[AnyRef](et).union(array2.toArray[AnyRef](et)) - .distinct.asInstanceOf[Array[Any]]) } - def arrayIntersect(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { - new GenericArrayData(array1.toArray[AnyRef](et).intersect(array2.toArray[AnyRef](et)) - .distinct.asInstanceOf[Array[Any]]) - } - - def arrayExcept(array1: ArrayData, array2: ArrayData, et: DataType): ArrayData = { - new GenericArrayData(array1.toArray[AnyRef](et).diff(array2.toArray[AnyRef](et)) - .distinct.asInstanceOf[Array[Any]]) - } -} - -abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes { - def typeId: Int - - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) - - override def checkInputDataTypes(): TypeCheckResult = { - val r = super.checkInputDataTypes() - if ((r == TypeCheckResult.TypeCheckSuccess) && - (left.dataType.asInstanceOf[ArrayType].elementType != - right.dataType.asInstanceOf[ArrayType].elementType)) { - TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same") - } else { - r + def evalIntLongPrimitiveType( + array1: ArrayData, + array2: ArrayData, + resultArray: ArrayData, + isLongType: Boolean): Int = { + // store elements into resultArray + var exceptNullElement = true + var i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + exceptNullElement = false + } else { + val assigned = if (!isLongType) { + hsInt.add(array2.getInt(i)) + } else { + hsLong.add(array2.getLong(i)) + } + } + i += 1 + } + var pos = 0 + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (exceptNullElement) { + if (resultArray != null) { + resultArray.setNullAt(pos) + } + pos += 1 + exceptNullElement = false + } + } else { + val assigned = if (!isLongType) { + assignInt(array1, i, resultArray, pos) + } else { + assignLong(array1, i, resultArray, pos) + } + if (assigned) { + pos += 1 + } + } + i += 1 } + pos } - override def dataType: DataType = left.dataType - - private def elementType = 
dataType.asInstanceOf[ArrayType].elementType - private def cn = left.dataType.asInstanceOf[ArrayType].containsNull || - right.dataType.asInstanceOf[ArrayType].containsNull - - def intEval(ary: ArrayData, hs2: OpenHashSet[Int]): OpenHashSet[Int] - def longEval(ary: ArrayData, hs2: OpenHashSet[Long]): OpenHashSet[Long] - def genericEval(ary: ArrayData, hs2: OpenHashSet[Any], et: DataType): OpenHashSet[Any] - def codeGen(ctx: CodegenContext, hs2: String, hs: String, len: String, getter: String, i: String, - postFix: String, newOpenHashSet: String): String - override def nullSafeEval(input1: Any, input2: Any): Any = { - val ary1 = input1.asInstanceOf[ArrayData] - val ary2 = input2.asInstanceOf[ArrayData] + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] - if (!cn) { + if (elementTypeSupportEquals) { elementType match { case IntegerType => // avoid boxing of primitive int array elements - val hs2 = new OpenHashSet[Int] - var i = 0 - while (i < ary2.numElements()) { - hs2.add(ary2.getInt(i)) - i += 1 + // calculate result array size + hsInt = new OpenHashSet[Int] + val elements = evalIntLongPrimitiveType(array1, array2, null, false) + // allocate result array + hsInt = new OpenHashSet[Int] + val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( + IntegerType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) + } else { + UnsafeArrayData.forPrimitiveArray( + Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) } - ArraySetUtils.toUnsafeIntArray(intEval(ary1, hs2)) + // assign elements into the result array + evalIntLongPrimitiveType(array1, array2, resultArray, false) + resultArray case LongType => // avoid boxing of primitive long array elements - val hs2 = new OpenHashSet[Long] - var i = 0 - while (i < ary2.numElements()) { - hs2.add(ary2.getLong(i)) - i += 1 + // calculate result array size + hsLong = new OpenHashSet[Long] + val elements = evalIntLongPrimitiveType(array1, array2, null, true) + // allocate result array + hsLong = new OpenHashSet[Long] + val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( + LongType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) + } else { + UnsafeArrayData.forPrimitiveArray( + Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) } - ArraySetUtils.toUnsafeLongArray(longEval(ary1, hs2)) + // assign elements into the result array + evalIntLongPrimitiveType(array1, array2, resultArray, true) + resultArray case _ => - var hs: OpenHashSet[Any] = null - val hs1 = new OpenHashSet[Any] + val hs = new OpenHashSet[Any] + var exceptNullElement = true var i = 0 - while (i < ary2.numElements()) { - hs1.add(ary2.get(i, elementType)) + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + exceptNullElement = false + } else { + val elem = array2.get(i, elementType) + hs.add(elem) + } i += 1 } - new GenericArrayData(genericEval(ary1, hs1, elementType).iterator.toArray) + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (exceptNullElement) { + arrayBuffer += null + exceptNullElement = false + } + } else { + val elem = array1.get(i, elementType) + if (!hs.contains(elem)) { + arrayBuffer += elem + hs.add(elem) + } + } + i += 1 + } + new GenericArrayData(arrayBuffer) } } else { - if (typeId == ArraySetUtils.kindUnion) { - ArraySetUtils.arrayUnion(ary1, ary2, elementType) - } else if (typeId == ArraySetUtils.kindIntersect) { - 
ArraySetUtils.arrayIntersect(ary1, ary2, elementType) - } else { - ArraySetUtils.arrayExcept(ary1, ary2, elementType) - } + ArrayExcept.exceptOrdering(array1, array2, elementType, ordering) } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val i = ctx.freshName("i") - val arraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetUtils" - val genericArrayData = classOf[GenericArrayData].getName - val unsafeArrayData = classOf[UnsafeArrayData].getName - val openHashSet = classOf[OpenHashSet[_]].getName - val et = s"org.apache.spark.sql.types.DataTypes.$elementType" - val (postFix, classTag, getter, arrayBuilder, javaTypeName) = if (!cn) { - val ptName = CodeGenerator.primitiveTypeName(elementType) - elementType match { - case ByteType | ShortType | IntegerType => - (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", - s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType)) - case LongType => - (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)", - s"$unsafeArrayData.fromPrimitiveArray", "long") - case _ => - ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $et)", - s"new $genericArrayData", "Object") + val pos = ctx.freshName("pos") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) = + if (elementTypeSupportEquals) { + elementType match { + case ByteType | ShortType | IntegerType | LongType => + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", + if (elementType == LongType) "Long" else "Int", + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + if (elementType == LongType) "(long)" else "(int)", + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case _ => + val genericArrayData = classOf[GenericArrayData].getName + val et = ctx.addReferenceObj("elementType", elementType) + ("", "Object", + s"get($i, $et)", s"update($pos, $value)", "Object", "", + s"${ev.value} = new $genericArrayData(new Object[$size]);") + } + } else { + ("", "", "", "", "", "", "") } - } else { - ("", "", "", "", "") - } - val hs = ctx.freshName("hs") - val hs2 = ctx.freshName("hs2") - val invalidPos = ctx.freshName("invalidPos") - val pos = ctx.freshName("pos") - val ary = ctx.freshName("ary") - nullSafeCodeGen(ctx, ev, (ary1, ary2) => { - if (classTag != "") { - val secondLoop = codeGen(ctx, hs2, hs, s"$ary1.numElements()", s"$ary1.$getter", i, - postFix, s"new $openHashSet$postFix($classTag)") + nullSafeCodeGen(ctx, ev, (array1, array2) => { + if (openHashElementType != "") { + // Here, we ensure elementTypeSupportEquals is true + val exceptNullElement = ctx.freshName("exceptNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" + val hs = ctx.freshName("hs") + val arrayData = classOf[ArrayData].getName + val arrays = ctx.freshName("arrays") + val array = ctx.freshName("array") + val arrayDataIdx = ctx.freshName("arrayDataIdx") s""" - |$openHashSet $hs2 = new $openHashSet$postFix($classTag); - |for (int $i = 0; $i < $ary2.numElements(); $i++) { - | $hs2.add$postFix($ary2.$getter); + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |boolean 
$exceptNullElement = true; + |int $size = 0; + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | if ($array2.isNullAt($i)) { + | $exceptNullElement = false; + | } else { + | $hs.add$postFix($array2.$getter); + | } |} - |$secondLoop - |$javaTypeName[] $ary = new $javaTypeName[$hs.size()]; - |int $invalidPos = $openHashSet.INVALID_POS(); - |int $pos = $hs.nextPos(0); - |int $i = 0; - |while ($pos != $invalidPos) { - | $ary[$i] = ($javaTypeName) $hs.getValue$postFix($pos); - | $pos = $hs.nextPos($pos + 1); - | $i++; + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | if ($array1.isNullAt($i)) { + | if ($exceptNullElement) { + | $size++; + | $exceptNullElement = false; + | } + | } else { + | $javaTypeName $value = $array1.$getter; + | if (!$hs.contains($castOp $value)) { + | $hs.add$postFix($value); + | $size++; + | } + | } + |} + |$arrayBuilder + |$hs = new $openHashSet$postFix($classTag); + |$exceptNullElement = true; + |int $pos = 0; + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | if ($array2.isNullAt($i)) { + | $exceptNullElement = false; + | } else { + | $hs.add$postFix($array2.$getter); + | } + |} + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | if ($array1.isNullAt($i)) { + | if ($exceptNullElement) { + | ${ev.value}.setNullAt($pos++); + | $exceptNullElement = false; + | } + | } else { + | $javaTypeName $value = $array1.$getter; + | if (!$hs.contains($castOp $value)) { + | $hs.add$postFix($value); + | ${ev.value}.$setter; + | $pos++; + | } + | } |} - |${ev.value} = $arrayBuilder($ary); """.stripMargin } else { - val setOp = if (typeId == ArraySetUtils.kindUnion) { - "Union" - } else if (typeId == ArraySetUtils.kindIntersect) { - "Intersect" - } else { - "Except" - } - s"${ev.value} = $arraySetUtils$$.MODULE$$.array$setOp($ary1, $ary2, $et);" + val arrayExcept = classOf[ArrayExcept].getName + val et = ctx.addReferenceObj("elementTypeIntersect", elementType) + val order = ctx.addReferenceObj("orderingIntersect", ordering) + val method = "exceptOrdering" + s"${ev.value} = $arrayExcept$$.MODULE$$.$method($array1, $array2, $et, $order);" } }) } -} - -/** - * Returns an array of the elements in the union of x and y, without duplicates - */ -@ExpressionDescription( - usage = """ - _FUNC_(array1, array2) - Returns an array of the elements in array1 but not in array2, - without duplicates. The order of elements in the result is not determined. 
- """, - examples = """ - Examples: - > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); - array(2) - """, - since = "2.4.0") -case class ArrayExcept(left: Expression, right: Expression) extends ArraySetUtils { - override def typeId: Int = ArraySetUtils.kindExcept - override def intEval(ary: ArrayData, hs1: OpenHashSet[Int]): OpenHashSet[Int] = { - val hs = new OpenHashSet[Int] - var i = 0 - while (i < ary.numElements()) { - val k = ary.getInt(i) - if (!hs1.contains(k)) { - hs.add(k) - } - i += 1 - } - hs - } + override def prettyName: String = "array_except" +} - override def longEval(ary: ArrayData, hs1: OpenHashSet[Long]): OpenHashSet[Long] = { - val hs = new OpenHashSet[Long] +object ArrayExcept { + def exceptOrdering( + array1: ArrayData, + array2: ArrayData, + elementType: DataType, + ordering: Ordering[Any]): ArrayData = { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var exceptNullElement = false var i = 0 - while (i < ary.numElements()) { - val k = ary.getLong(i) - if (!hs1.contains(k)) { - hs.add(k) + while (i < array1.numElements()) { + var found = false + var elem1 = array1.get(i, elementType) + if (array1.isNullAt(i)) { + if (!exceptNullElement) { + var j = 0 + while (!found && j < array2.numElements()) { + found = array2.isNullAt(j) + j += 1 + } + // array2 is scaned only once for null element + exceptNullElement = true + } else { + found = true + } + } else { + var j = 0 + while (!found && j < array2.numElements()) { + if (!array2.isNullAt(j)) { + val elem2 = array2.get(j, elementType) + found = ordering.equiv(elem1, elem2) + } + j += 1 + } + if (!found) { + // check whether elem1 is already stored in arrayBuffer + var k = 0 + while (!found && k < arrayBuffer.size) { + val va = arrayBuffer(k) + found = (va != null) && ordering.equiv(va, elem1) + k += 1 + } + } } - i += 1 - } - hs - } - - override def genericEval( - ary: ArrayData, - hs1: OpenHashSet[Any], - et: DataType): OpenHashSet[Any] = { - val hs = new OpenHashSet[Any] - var i = 0 - while (i < ary.numElements()) { - val k = ary.get(i, et) - if (!hs1.contains(k)) { - hs.add(k) + if (!found) { + arrayBuffer += elem1 } i += 1 } - hs - } - - override def codeGen( - ctx: CodegenContext, - hs1: String, - hs: String, - len: String, - getter: String, - i: String, - postFix: String, - newOpenHashSet: String): String = { - val openHashSet = classOf[OpenHashSet[_]].getName - s""" - |$openHashSet $hs = $newOpenHashSet; - |for (int $i = 0; $i < $len; $i++) { - | if (!$hs1.contains$postFix($getter)) { - | $hs.add$postFix($getter); - | } - |} - """.stripMargin + new GenericArrayData(arrayBuffer) } - - override def prettyName: String = "array_except" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 5b9b8c2fc0d21..1ca4aa37f76c4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1032,9 +1032,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(0)), null) }.getMessage.contains("SQL array indices start at 1") - intercept[Exception] { - checkEvaluation(ElementAt(a0, Literal(1.1)), null) - } + intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) } 
checkEvaluation(ElementAt(a0, Literal(4)), null) checkEvaluation(ElementAt(a0, Literal(-4)), null) @@ -1507,50 +1505,107 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("Array Except") { - val a00 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType, false)) + val a00 = Literal.create(Seq(1, 2, 4, 3), ArrayType(IntegerType, false)) val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) - val a02 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType)) - val a03 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType)) - val a04 = Literal.create(Seq(-5, 4, null, 2, -1), ArrayType(IntegerType)) - val a05 = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) - - val a10 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType, false)) + val a02 = Literal.create(Seq(1, 2, 4, 2), ArrayType(IntegerType, false)) + val a03 = Literal.create(Seq(4, 2, 4), ArrayType(IntegerType, false)) + val a04 = Literal.create(Seq(1, 2, null, 4, 5, 1), ArrayType(IntegerType, true)) + val a05 = Literal.create(Seq(-5, 4, null, 2, -1), ArrayType(IntegerType, true)) + val a06 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, false)) + val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, false)) + val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, false)) + val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, false)) + val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, false)) + + val a10 = Literal.create(Seq(1L, 2L, 4L, 3L), ArrayType(LongType, false)) val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) - val a12 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType)) - val a13 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType)) - val a14 = Literal.create(Seq(-5L, 4L, null, 2L, -1L), ArrayType(LongType)) - val a15 = Literal.create(Seq.empty[Long], ArrayType(LongType)) - - val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType)) - val a21 = Literal.create(Seq("c", null, "a", "f"), ArrayType(StringType)) - val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType)) - val a23 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false)) - val a24 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, false)) - - val a30 = Literal.create(Seq(null, null), ArrayType(NullType)) + val a12 = Literal.create(Seq(1L, 2L, 4L, 2L), ArrayType(LongType, false)) + val a13 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) + val a14 = Literal.create(Seq(1L, 2L, null, 4L, 5L, 1L), ArrayType(LongType, true)) + val a15 = Literal.create(Seq(-5L, 4L, null, 2L, -1L), ArrayType(LongType, true)) + val a16 = Literal.create(Seq.empty[Long], ArrayType(LongType, false)) + + val a20 = Literal.create(Seq("b", "a", "c", "d"), ArrayType(StringType, false)) + val a21 = Literal.create(Seq("c", "a"), ArrayType(StringType, false)) + val a22 = Literal.create(Seq("b", "a", "c", "a"), ArrayType(StringType, false)) + val a23 = Literal.create(Seq("c", "a", "c"), ArrayType(StringType, false)) + val a24 = Literal.create(Seq("c", null, "a", "f", "c"), ArrayType(StringType, true)) + val a25 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, true)) + val a26 = Literal.create(Seq.empty[String], ArrayType(StringType, false)) + + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) val a31 = Literal.create(null, ArrayType(StringType)) - checkEvaluation(ArrayExcept(a00, a01), UnsafeArrayData.fromPrimitiveArray(Array(1))) + 
checkEvaluation(ArrayExcept(a00, a01), Seq(1, 3)) checkEvaluation(ArrayExcept(a02, a01), Seq(1)) - checkEvaluation(ArrayExcept(a03, a02), Seq(null, 5)) - checkEvaluation(ArrayExcept(a03, a04), Seq(1, 5)) - checkEvaluation(ArrayExcept(a03, a05), Seq(1, 2, null, 4, 5)) - checkEvaluation(ArrayExcept(a05, a03), Seq.empty) - - checkEvaluation(ArrayExcept(a10, a11), UnsafeArrayData.fromPrimitiveArray(Array(1L))) + checkEvaluation(ArrayExcept(a02, a02), Seq.empty) + checkEvaluation(ArrayExcept(a02, a03), Seq(1)) + checkEvaluation(ArrayExcept(a04, a02), Seq(null, 5)) + checkEvaluation(ArrayExcept(a04, a05), Seq(1, 5)) + checkEvaluation(ArrayExcept(a04, a06), Seq(1, 2, null, 4, 5)) + checkEvaluation(ArrayExcept(a06, a04), Seq.empty) + checkEvaluation(ArrayExcept(ab0, ab1), Seq[Byte](1, 3)) + checkEvaluation(ArrayExcept(as0, as1), Seq[Short](1, 3)) + + checkEvaluation(ArrayExcept(a10, a11), Seq(1L, 3L)) checkEvaluation(ArrayExcept(a12, a11), Seq(1L)) - checkEvaluation(ArrayExcept(a13, a12), Seq(null, 5L)) - checkEvaluation(ArrayExcept(a13, a14), Seq(1L, 5L)) - checkEvaluation(ArrayExcept(a13, a15), Seq(1L, 2L, null, 4L, 5L)) - checkEvaluation(ArrayExcept(a15, a13), Seq.empty) - - checkEvaluation(ArrayExcept(a20, a21), Seq("b")) - checkEvaluation(ArrayExcept(a21, a22), Seq("c", "f")) - checkEvaluation(ArrayExcept(a22, a23), Seq(null, "g")) - checkEvaluation(ArrayExcept(a23, a24), Seq("b")) + checkEvaluation(ArrayExcept(a12, a12), Seq.empty) + checkEvaluation(ArrayExcept(a12, a13), Seq(1L)) + checkEvaluation(ArrayExcept(a14, a12), Seq(null, 5L)) + checkEvaluation(ArrayExcept(a14, a15), Seq(1L, 5L)) + checkEvaluation(ArrayExcept(a14, a16), Seq(1L, 2L, null, 4L, 5L)) + checkEvaluation(ArrayExcept(a16, a14), Seq.empty) + + checkEvaluation(ArrayExcept(a20, a21), Seq("b", "d")) + checkEvaluation(ArrayExcept(a22, a21), Seq("b")) + checkEvaluation(ArrayExcept(a22, a22), Seq.empty) + checkEvaluation(ArrayExcept(a22, a23), Seq("b")) + checkEvaluation(ArrayExcept(a24, a22), Seq(null, "f")) + checkEvaluation(ArrayExcept(a24, a25), Seq("c", "f")) + checkEvaluation(ArrayExcept(a24, a26), Seq("c", null, "a", "f")) + checkEvaluation(ArrayExcept(a26, a24), Seq.empty) checkEvaluation(ArrayExcept(a30, a30), Seq.empty) checkEvaluation(ArrayExcept(a20, a31), null) checkEvaluation(ArrayExcept(a31, a20), null) + + val b0 = Literal.create( + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](3, 4), Array[Byte](7, 8)), + ArrayType(BinaryType)) + val b1 = Literal.create( + Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b2 = Literal.create( + Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), null), + ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](3, 4), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + + checkEvaluation(ArrayExcept(b0, b1), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](7, 8))) + checkEvaluation(ArrayExcept(b1, b0), Seq[Array[Byte]](Array[Byte](2, 1))) + checkEvaluation(ArrayExcept(b0, b2), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](7, 8))) + checkEvaluation(ArrayExcept(b2, b0), Seq.empty) + checkEvaluation(ArrayExcept(b2, b3), Seq[Array[Byte]](Array[Byte](1, 2))) + checkEvaluation(ArrayExcept(b3, b2), Seq[Array[Byte]](Array[Byte](2, 1), null)) + 
checkEvaluation(ArrayExcept(b3, b4), Seq[Array[Byte]](Array[Byte](2, 1))) + checkEvaluation(ArrayExcept(b4, b3), Seq.empty) + checkEvaluation(ArrayExcept(b4, b5), Seq[Array[Byte]](null, Array[Byte](3, 4))) + checkEvaluation(ArrayExcept(b5, b4), Seq.empty) + checkEvaluation(ArrayExcept(b4, arrayWithBinaryNull), Seq[Array[Byte]](Array[Byte](3, 4))) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](3, 4), Seq[Int](2, 1), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayExcept(aa0, aa1), Seq[Seq[Int]](Seq[Int](1, 2))) + checkEvaluation(ArrayExcept(aa1, aa0), Seq[Seq[Int]](Seq[Int](2, 1))) + + assert(ArrayExcept(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayExcept(a00, a04).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayExcept(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayExcept(a20, a24).dataType.asInstanceOf[ArrayType].containsNull === true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 1979ccd4b77a7..ccabd1842867f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1204,42 +1204,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } -<<<<<<< HEAD // Test with local relation, the Project will be evaluated without codegen simpleTest() // Test with cached relation, the Project will be evaluated with codegen df.cache() simpleTest() -======= - checkAnswer( - df.select(concat($"i1", $"s1")), - Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) - ) - checkAnswer( - df.select(concat($"i1", $"i2", $"i3")), - Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) - ) - checkAnswer( - df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")), - Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) - ) - checkAnswer( - df.selectExpr("concat(array(1, null), i2, i3)"), - Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) - ) - checkAnswer( - df.select(concat($"s1", $"s2", $"s3")), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) - ) - checkAnswer( - df.selectExpr("concat(s1, s2, s3)"), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) - ) - checkAnswer( - df.filter(dummyFilter($"s1")) select (concat($"s1", $"s2", $"s3")), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) - ) ->>>>>>> initial commit // Null test cases def nullTest(): Unit = { @@ -1640,10 +1609,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df5.selectExpr("array_except(a, b)"), ans5) val df6 = Seq((null, null)).toDF("a", "b") - val ans6 = Row(null) - checkAnswer(df6.select(array_except($"a", $"b")), ans6) - checkAnswer(df6.selectExpr("array_except(a, b)"), ans6) - + intercept[AnalysisException] { + df6.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df6.selectExpr("array_except(a, b)") + } val df7 = Seq((Array(1), Array("a"))).toDF("a", "b") intercept[AnalysisException] { df7.select(array_except($"a", $"b")) From bfd0509b092f48d075c493b70ba021262edf66af Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Jul 2018 08:50:51 +0100 Subject: [PATCH 08/29] update --- 
core/src/main/scala/org/apache/spark/package.scala | 6 ++---- .../catalyst/expressions/collectionOperations.scala | 12 ++++++------ .../expressions/CollectionExpressionsSuite.scala | 5 +++-- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index f46808670d445..8058a4d5dbdea 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -57,7 +57,7 @@ package object spark { val resourceStream = Thread.currentThread().getContextClassLoader. getResourceAsStream("spark-version-info.properties") if (resourceStream == null) { - // throw new SparkException("Could not find spark-version-info.properties") + throw new SparkException("Could not find spark-version-info.properties") } try { @@ -74,9 +74,7 @@ package object spark { ) } catch { case e: Exception => - val unknownProp = "" - (unknownProp, unknownProp, unknownProp, unknownProp, unknownProp, unknownProp) - // throw new SparkException("Error loading properties from spark-version-info.properties", e) + throw new SparkException("Error loading properties from spark-version-info.properties", e) } finally { if (resourceStream != null) { try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 324d88f41ed6e..8ef9ea87d0485 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3652,11 +3652,6 @@ case class ArrayDistinct(child: Expression) * Will become common base class for [[ArrayUnion]], ArrayIntersect, and [[ArrayExcept]]. 
*/ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { - override def dataType: DataType = { - val dataTypes = children.map(_.dataType.asInstanceOf[ArrayType]) - ArrayType(elementType, dataTypes.exists(_.containsNull)) - } - override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() if (typeCheckResult.isSuccess) { @@ -3700,7 +3695,8 @@ object ArraySetLike { array(1, 2, 3, 5) """, since = "2.4.0") -case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { var hsInt: OpenHashSet[Int] = _ var hsLong: OpenHashSet[Long] = _ @@ -3982,6 +3978,10 @@ object ArrayUnion { """, since = "2.4.0") case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike { + override def dataType: DataType = ArrayType(elementType, + left.dataType.asInstanceOf[ArrayType].containsNull && + !right.dataType.asInstanceOf[ArrayType].containsNull) + var hsInt: OpenHashSet[Int] = _ var hsLong: OpenHashSet[Long] = _ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 1ca4aa37f76c4..392fc5555ca9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1604,8 +1604,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayExcept(aa1, aa0), Seq[Seq[Int]](Seq[Int](2, 1))) assert(ArrayExcept(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) - assert(ArrayExcept(a00, a04).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayExcept(a04, a02).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayExcept(a04, a05).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayExcept(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) - assert(ArrayExcept(a20, a24).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true) } } From 7966cc366fe4be129e6ddb6fc2ec204954062bd9 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 21 Jul 2018 19:48:28 +0100 Subject: [PATCH 09/29] address review comment --- .../catalyst/expressions/collectionOperations.scala | 7 +++---- .../expressions/CollectionExpressionsSuite.scala | 2 +- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 10 ++++++++++ 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8ef9ea87d0485..61bf4d434fd39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3972,15 +3972,14 @@ object ArrayUnion { without duplicates. 
""", examples = """ - Examples:Fun + Examples: > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); array(2) """, since = "2.4.0") case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike { - override def dataType: DataType = ArrayType(elementType, - left.dataType.asInstanceOf[ArrayType].containsNull && - !right.dataType.asInstanceOf[ArrayType].containsNull) + override def dataType: DataType = + ArrayType(elementType, left.dataType.asInstanceOf[ArrayType].containsNull) var hsInt: OpenHashSet[Int] = _ var hsLong: OpenHashSet[Long] = _ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 392fc5555ca9e..603f9889e6809 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1605,7 +1605,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(ArrayExcept(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayExcept(a04, a02).dataType.asInstanceOf[ArrayType].containsNull === true) - assert(ArrayExcept(a04, a05).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayExcept(a04, a05).dataType.asInstanceOf[ArrayType].containsNull === true) assert(ArrayExcept(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ccabd1842867f..bda149dfa2281 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1636,6 +1636,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { intercept[AnalysisException] { df9.selectExpr("array_except(a, b)") } + + val df10 = Seq( + (Array[Integer](1, 2), Array[Integer](2)), + (Array[Integer](1, 2), Array[Integer](1, null)), + (Array[Integer](1, null, 3) , Array[Integer](1, 2)), + (Array[Integer](1, null), Array[Integer](2, null)) + ).toDF("a", "b") + val result10 = df10.select(array_except($"a", $"b")) + val expectedType10 = ArrayType(IntegerType, containsNull = true) + assert(result10.first.schema(0).dataType === expectedType10) } private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { From c68a75f69c6e1507539176d44c9e17b1034ad15b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 21 Jul 2018 21:02:11 +0100 Subject: [PATCH 10/29] fix scala style --- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index bda149dfa2281..0ea444ca14709 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1640,7 +1640,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val df10 = Seq( (Array[Integer](1, 2), Array[Integer](2)), (Array[Integer](1, 2), 
Array[Integer](1, null)), - (Array[Integer](1, null, 3) , Array[Integer](1, 2)), + (Array[Integer](1, null, 3), Array[Integer](1, 2)), (Array[Integer](1, null), Array[Integer](2, null)) ).toDF("a", "b") val result10 = df10.select(array_except($"a", $"b")) From 325e22084849b013ca751cfb613ad87ef51ce0d7 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 24 Jul 2018 06:46:04 +0100 Subject: [PATCH 11/29] address review comments --- .../catalyst/expressions/collectionOperations.scala | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 61bf4d434fd39..9160f253d1753 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3978,8 +3978,7 @@ object ArrayUnion { """, since = "2.4.0") case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike { - override def dataType: DataType = - ArrayType(elementType, left.dataType.asInstanceOf[ArrayType].containsNull) + override def dataType: DataType = left.dataType var hsInt: OpenHashSet[Int] = _ var hsLong: OpenHashSet[Long] = _ @@ -4016,11 +4015,11 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike resultArray: ArrayData, isLongType: Boolean): Int = { // store elements into resultArray - var exceptNullElement = true + var notFoundNullElement = true var i = 0 while (i < array2.numElements()) { if (array2.isNullAt(i)) { - exceptNullElement = false + notFoundNullElement = false } else { val assigned = if (!isLongType) { hsInt.add(array2.getInt(i)) @@ -4034,12 +4033,12 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike i = 0 while (i < array1.numElements()) { if (array1.isNullAt(i)) { - if (exceptNullElement) { + if (notFoundNullElement) { if (resultArray != null) { resultArray.setNullAt(pos) } pos += 1 - exceptNullElement = false + notFoundNullElement = false } } else { val assigned = if (!isLongType) { From 27d0f7a88c814582db328a510a54c22fc9c360ae Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 26 Jul 2018 03:36:27 +0100 Subject: [PATCH 12/29] address review comment --- .../expressions/collectionOperations.scala | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 9160f253d1753..567a98d8324b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4097,11 +4097,11 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike resultArray case _ => val hs = new OpenHashSet[Any] - var exceptNullElement = true + var notFoundNullElement = true var i = 0 while (i < array2.numElements()) { if (array2.isNullAt(i)) { - exceptNullElement = false + notFoundNullElement = false } else { val elem = array2.get(i, elementType) hs.add(elem) @@ -4112,9 +4112,9 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike i = 0 while (i < array1.numElements()) { 
if (array1.isNullAt(i)) { - if (exceptNullElement) { + if (notFoundNullElement) { arrayBuffer += null - exceptNullElement = false + notFoundNullElement = false } } else { val elem = array1.get(i, elementType) @@ -4165,7 +4165,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike nullSafeCodeGen(ctx, ev, (array1, array2) => { if (openHashElementType != "") { // Here, we ensure elementTypeSupportEquals is true - val exceptNullElement = ctx.freshName("exceptNullElement") + val notFoundNullElement = ctx.freshName("notFoundNullElement") val openHashSet = classOf[OpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" val hs = ctx.freshName("hs") @@ -4175,20 +4175,20 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val arrayDataIdx = ctx.freshName("arrayDataIdx") s""" |$openHashSet $hs = new $openHashSet$postFix($classTag); - |boolean $exceptNullElement = true; + |boolean $notFoundNullElement = true; |int $size = 0; |for (int $i = 0; $i < $array2.numElements(); $i++) { | if ($array2.isNullAt($i)) { - | $exceptNullElement = false; + | $notFoundNullElement = false; | } else { | $hs.add$postFix($array2.$getter); | } |} |for (int $i = 0; $i < $array1.numElements(); $i++) { | if ($array1.isNullAt($i)) { - | if ($exceptNullElement) { + | if ($notFoundNullElement) { | $size++; - | $exceptNullElement = false; + | $notFoundNullElement = false; | } | } else { | $javaTypeName $value = $array1.$getter; @@ -4200,20 +4200,20 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike |} |$arrayBuilder |$hs = new $openHashSet$postFix($classTag); - |$exceptNullElement = true; + |$notFoundNullElement = true; |int $pos = 0; |for (int $i = 0; $i < $array2.numElements(); $i++) { | if ($array2.isNullAt($i)) { - | $exceptNullElement = false; + | $notFoundNullElement = false; | } else { | $hs.add$postFix($array2.$getter); | } |} |for (int $i = 0; $i < $array1.numElements(); $i++) { | if ($array1.isNullAt($i)) { - | if ($exceptNullElement) { + | if ($notFoundNullElement) { | ${ev.value}.setNullAt($pos++); - | $exceptNullElement = false; + | $notFoundNullElement = false; | } | } else { | $javaTypeName $value = $array1.$getter; @@ -4245,20 +4245,20 @@ object ArrayExcept { elementType: DataType, ordering: Ordering[Any]): ArrayData = { val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - var exceptNullElement = false + var scannedNullElements = false var i = 0 while (i < array1.numElements()) { var found = false var elem1 = array1.get(i, elementType) if (array1.isNullAt(i)) { - if (!exceptNullElement) { + if (!scannedNullElements) { var j = 0 while (!found && j < array2.numElements()) { found = array2.isNullAt(j) j += 1 } - // array2 is scaned only once for null element - exceptNullElement = true + // array2 is scanned only once for null element + scannedNullElements = true } else { found = true } From 02b809bf11fe225ae5de16cbe8e2cda37d0d56b4 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 26 Jul 2018 21:11:28 +0100 Subject: [PATCH 13/29] address review comments --- .../expressions/collectionOperations.scala | 332 ++++++------------ .../CollectionExpressionsSuite.scala | 6 - 2 files changed, 116 insertions(+), 222 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 567a98d8324b6..c37af901f284d 
100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4055,237 +4055,137 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike pos } - override def nullSafeEval(input1: Any, input2: Any): Any = { - val array1 = input1.asInstanceOf[ArrayData] - val array2 = input2.asInstanceOf[ArrayData] - - if (elementTypeSupportEquals) { - elementType match { - case IntegerType => - // avoid boxing of primitive int array elements - // calculate result array size - hsInt = new OpenHashSet[Int] - val elements = evalIntLongPrimitiveType(array1, array2, null, false) - // allocate result array - hsInt = new OpenHashSet[Int] - val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( - IntegerType.defaultSize, elements)) { - new GenericArrayData(new Array[Any](elements)) - } else { - UnsafeArrayData.forPrimitiveArray( - Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) - } - // assign elements into the result array - evalIntLongPrimitiveType(array1, array2, resultArray, false) - resultArray - case LongType => - // avoid boxing of primitive long array elements - // calculate result array size - hsLong = new OpenHashSet[Long] - val elements = evalIntLongPrimitiveType(array1, array2, null, true) - // allocate result array - hsLong = new OpenHashSet[Long] - val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( - LongType.defaultSize, elements)) { - new GenericArrayData(new Array[Any](elements)) - } else { - UnsafeArrayData.forPrimitiveArray( - Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) - } - // assign elements into the result array - evalIntLongPrimitiveType(array1, array2, resultArray, true) - resultArray - case _ => - val hs = new OpenHashSet[Any] - var notFoundNullElement = true - var i = 0 - while (i < array2.numElements()) { - if (array2.isNullAt(i)) { - notFoundNullElement = false + val exceptEquals: (ArrayData, ArrayData) => ArrayData = { + (array1: ArrayData, array2: ArrayData) => + if (elementTypeSupportEquals) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + // calculate result array size + hsInt = new OpenHashSet[Int] + val elements = evalIntLongPrimitiveType(array1, array2, null, false) + // allocate result array + hsInt = new OpenHashSet[Int] + val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( + IntegerType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) } else { - val elem = array2.get(i, elementType) - hs.add(elem) + UnsafeArrayData.forPrimitiveArray( + Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) } - i += 1 - } - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - i = 0 - while (i < array1.numElements()) { - if (array1.isNullAt(i)) { - if (notFoundNullElement) { - arrayBuffer += null + // assign elements into the result array + evalIntLongPrimitiveType(array1, array2, resultArray, false) + resultArray + case LongType => + // avoid boxing of primitive long array elements + // calculate result array size + hsLong = new OpenHashSet[Long] + val elements = evalIntLongPrimitiveType(array1, array2, null, true) + // allocate result array + hsLong = new OpenHashSet[Long] + val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( + LongType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) + } else { + 
UnsafeArrayData.forPrimitiveArray( + Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) + } + // assign elements into the result array + evalIntLongPrimitiveType(array1, array2, resultArray, true) + resultArray + case _ => + val hs = new OpenHashSet[Any] + var notFoundNullElement = true + var i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { notFoundNullElement = false + } else { + val elem = array2.get(i, elementType) + hs.add(elem) + } + i += 1 + } + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (notFoundNullElement) { + arrayBuffer += null + notFoundNullElement = false + } + } else { + val elem = array1.get(i, elementType) + if (!hs.contains(elem)) { + arrayBuffer += elem + hs.add(elem) + } } + i += 1 + } + new GenericArrayData(arrayBuffer) + } + } else { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var scannedNullElements = false + var i = 0 + while (i < array1.numElements()) { + var found = false + val elem1 = array1.get(i, elementType) + if (elem1 == null) { + if (!scannedNullElements) { + var j = 0 + while (!found && j < array2.numElements()) { + found = array2.isNullAt(j) + j += 1 + } + // array2 is scanned only once for null element + scannedNullElements = true } else { - val elem = array1.get(i, elementType) - if (!hs.contains(elem)) { - arrayBuffer += elem - hs.add(elem) + found = true + } + } else { + var j = 0 + while (!found && j < array2.numElements()) { + val elem2 = array2.get(j, elementType) + if (elem2 != null) { + found = ordering.equiv(elem1, elem2) + } + j += 1 + } + if (!found) { + // check whether elem1 is already stored in arrayBuffer + var k = 0 + while (!found && k < arrayBuffer.size) { + val va = arrayBuffer(k) + found = (va != null) && ordering.equiv(va, elem1) + k += 1 } } - i += 1 } - new GenericArrayData(arrayBuffer) + if (!found) { + arrayBuffer += elem1 + } + i += 1 + } + new GenericArrayData(arrayBuffer) } - } else { - ArrayExcept.exceptOrdering(array1, array2, elementType, ordering) - } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val i = ctx.freshName("i") - val pos = ctx.freshName("pos") - val value = ctx.freshName("value") - val size = ctx.freshName("size") - val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) = - if (elementTypeSupportEquals) { - elementType match { - case ByteType | ShortType | IntegerType | LongType => - val ptName = CodeGenerator.primitiveTypeName(elementType) - val unsafeArray = ctx.freshName("unsafeArray") - (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", - if (elementType == LongType) "Long" else "Int", - s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), - if (elementType == LongType) "(long)" else "(int)", - s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; - """.stripMargin) - case _ => - val genericArrayData = classOf[GenericArrayData].getName - val et = ctx.addReferenceObj("elementType", elementType) - ("", "Object", - s"get($i, $et)", s"update($pos, $value)", "Object", "", - s"${ev.value} = new $genericArrayData(new Object[$size]);") - } - } else { - ("", "", "", "", "", "", "") - } + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + exceptEquals(array1, array2) + } + override def 
doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val arrayData = classOf[ArrayData].getName + val expr = ctx.addReferenceObj("arrayExceptExpr", this) nullSafeCodeGen(ctx, ev, (array1, array2) => { - if (openHashElementType != "") { - // Here, we ensure elementTypeSupportEquals is true - val notFoundNullElement = ctx.freshName("notFoundNullElement") - val openHashSet = classOf[OpenHashSet[_]].getName - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" - val hs = ctx.freshName("hs") - val arrayData = classOf[ArrayData].getName - val arrays = ctx.freshName("arrays") - val array = ctx.freshName("array") - val arrayDataIdx = ctx.freshName("arrayDataIdx") - s""" - |$openHashSet $hs = new $openHashSet$postFix($classTag); - |boolean $notFoundNullElement = true; - |int $size = 0; - |for (int $i = 0; $i < $array2.numElements(); $i++) { - | if ($array2.isNullAt($i)) { - | $notFoundNullElement = false; - | } else { - | $hs.add$postFix($array2.$getter); - | } - |} - |for (int $i = 0; $i < $array1.numElements(); $i++) { - | if ($array1.isNullAt($i)) { - | if ($notFoundNullElement) { - | $size++; - | $notFoundNullElement = false; - | } - | } else { - | $javaTypeName $value = $array1.$getter; - | if (!$hs.contains($castOp $value)) { - | $hs.add$postFix($value); - | $size++; - | } - | } - |} - |$arrayBuilder - |$hs = new $openHashSet$postFix($classTag); - |$notFoundNullElement = true; - |int $pos = 0; - |for (int $i = 0; $i < $array2.numElements(); $i++) { - | if ($array2.isNullAt($i)) { - | $notFoundNullElement = false; - | } else { - | $hs.add$postFix($array2.$getter); - | } - |} - |for (int $i = 0; $i < $array1.numElements(); $i++) { - | if ($array1.isNullAt($i)) { - | if ($notFoundNullElement) { - | ${ev.value}.setNullAt($pos++); - | $notFoundNullElement = false; - | } - | } else { - | $javaTypeName $value = $array1.$getter; - | if (!$hs.contains($castOp $value)) { - | $hs.add$postFix($value); - | ${ev.value}.$setter; - | $pos++; - | } - | } - |} - """.stripMargin - } else { - val arrayExcept = classOf[ArrayExcept].getName - val et = ctx.addReferenceObj("elementTypeIntersect", elementType) - val order = ctx.addReferenceObj("orderingIntersect", ordering) - val method = "exceptOrdering" - s"${ev.value} = $arrayExcept$$.MODULE$$.$method($array1, $array2, $et, $order);" - } + s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);" }) } override def prettyName: String = "array_except" } - -object ArrayExcept { - def exceptOrdering( - array1: ArrayData, - array2: ArrayData, - elementType: DataType, - ordering: Ordering[Any]): ArrayData = { - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - var scannedNullElements = false - var i = 0 - while (i < array1.numElements()) { - var found = false - var elem1 = array1.get(i, elementType) - if (array1.isNullAt(i)) { - if (!scannedNullElements) { - var j = 0 - while (!found && j < array2.numElements()) { - found = array2.isNullAt(j) - j += 1 - } - // array2 is scanned only once for null element - scannedNullElements = true - } else { - found = true - } - } else { - var j = 0 - while (!found && j < array2.numElements()) { - if (!array2.isNullAt(j)) { - val elem2 = array2.get(j, elementType) - found = ordering.equiv(elem1, elem2) - } - j += 1 - } - if (!found) { - // check whether elem1 is already stored in arrayBuffer - var k = 0 - while (!found && k < arrayBuffer.size) { - val va = arrayBuffer(k) - found = (va != null) && ordering.equiv(va, elem1) - k += 1 - } - } - } - if (!found) { - arrayBuffer += 
elem1 - } - i += 1 - } - new GenericArrayData(arrayBuffer) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 603f9889e6809..29852224ca3ba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1512,10 +1512,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a04 = Literal.create(Seq(1, 2, null, 4, 5, 1), ArrayType(IntegerType, true)) val a05 = Literal.create(Seq(-5, 4, null, 2, -1), ArrayType(IntegerType, true)) val a06 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, false)) - val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, false)) - val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, false)) - val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, false)) - val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, false)) val a10 = Literal.create(Seq(1L, 2L, 4L, 3L), ArrayType(LongType, false)) val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) @@ -1544,8 +1540,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayExcept(a04, a05), Seq(1, 5)) checkEvaluation(ArrayExcept(a04, a06), Seq(1, 2, null, 4, 5)) checkEvaluation(ArrayExcept(a06, a04), Seq.empty) - checkEvaluation(ArrayExcept(ab0, ab1), Seq[Byte](1, 3)) - checkEvaluation(ArrayExcept(as0, as1), Seq[Short](1, 3)) checkEvaluation(ArrayExcept(a10, a11), Seq(1L, 3L)) checkEvaluation(ArrayExcept(a12, a11), Seq(1L)) From 0d24fa95edf3cf8ec18a7bca34382e8aff24e505 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 27 Jul 2018 04:34:26 +0100 Subject: [PATCH 14/29] use better method name --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c37af901f284d..afb40afd7b097 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4055,7 +4055,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike pos } - val exceptEquals: (ArrayData, ArrayData) => ArrayData = { + val evalExcept: (ArrayData, ArrayData) => ArrayData = { (array1: ArrayData, array2: ArrayData) => if (elementTypeSupportEquals) { elementType match { @@ -4176,7 +4176,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val array1 = input1.asInstanceOf[ArrayData] val array2 = input2.asInstanceOf[ArrayData] - exceptEquals(array1, array2) + evalExcept(array1, array2) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { From 5115a49959eb8c50163e1d4166ac991f73909dcf Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 27 Jul 2018 05:10:35 +0100 Subject: [PATCH 15/29] address review comment --- .../expressions/collectionOperations.scala | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index afb40afd7b097..725c49b3d55e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4055,11 +4055,11 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike pos } - val evalExcept: (ArrayData, ArrayData) => ArrayData = { - (array1: ArrayData, array2: ArrayData) => - if (elementTypeSupportEquals) { - elementType match { - case IntegerType => + @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = { + if (elementTypeSupportEquals) { + elementType match { + case IntegerType => + (array1, array2) => // avoid boxing of primitive int array elements // calculate result array size hsInt = new OpenHashSet[Int] @@ -4076,7 +4076,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike // assign elements into the result array evalIntLongPrimitiveType(array1, array2, resultArray, false) resultArray - case LongType => + case LongType => + (array1, array2) => // avoid boxing of primitive long array elements // calculate result array size hsLong = new OpenHashSet[Long] @@ -4093,7 +4094,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike // assign elements into the result array evalIntLongPrimitiveType(array1, array2, resultArray, true) resultArray - case _ => + case _ => + (array1, array2) => val hs = new OpenHashSet[Any] var notFoundNullElement = true var i = 0 @@ -4124,8 +4126,9 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike i += 1 } new GenericArrayData(arrayBuffer) - } - } else { + } + } else { + (array1, array2) => val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] var scannedNullElements = false var i = 0 From e03876632f6f354cb437f3cd0ef9dd016d1ae542 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 27 Jul 2018 16:17:05 +0100 Subject: [PATCH 16/29] address review comments --- .../spark/util/collection/OpenHashSet.scala | 2 +- .../util/collection/OpenHashSetSuite.scala | 74 ++++++ .../expressions/collectionOperations.scala | 227 ++++++++++++------ 3 files changed, 230 insertions(+), 73 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 60f6f537c1d54..12a4eeb4abcc4 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -41,7 +41,7 @@ import org.apache.spark.annotation.Private * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). 
*/ @Private -class OpenHashSet[@specialized(Long, Int) T: ClassTag]( +class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( initialCapacity: Int, loadFactor: Double) extends Serializable { diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 210bc5c099742..b887f937a9da9 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -112,6 +112,80 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(!set.contains(10000L)) } + test("primitive float") { + val set = new OpenHashSet[Float] + assert(set.size === 0) + assert(!set.contains(10.1F)) + assert(!set.contains(50.5F)) + assert(!set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(10.1F) + assert(set.size === 1) + assert(set.contains(10.1F)) + assert(!set.contains(50.5F)) + assert(!set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(50.5F) + assert(set.size === 2) + assert(set.contains(10.1F)) + assert(set.contains(50.5F)) + assert(!set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(999.9F) + assert(set.size === 3) + assert(set.contains(10.1F)) + assert(set.contains(50.5F)) + assert(set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(50.5F) + assert(set.size === 3) + assert(set.contains(10.1F)) + assert(set.contains(50.5F)) + assert(set.contains(999.9F)) + assert(!set.contains(10000.1F)) + } + + test("primitive double") { + val set = new OpenHashSet[Double] + assert(set.size === 0) + assert(!set.contains(10.1D)) + assert(!set.contains(50.5D)) + assert(!set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(10.1D) + assert(set.size === 1) + assert(set.contains(10.1D)) + assert(!set.contains(50.5D)) + assert(!set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(50.5D) + assert(set.size === 2) + assert(set.contains(10.1D)) + assert(set.contains(50.5D)) + assert(!set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(999.9D) + assert(set.size === 3) + assert(set.contains(10.1D)) + assert(set.contains(50.5D)) + assert(set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(50.5D) + assert(set.size === 3) + assert(set.contains(10.1D)) + assert(set.contains(50.5D)) + assert(set.contains(999.9D)) + assert(!set.contains(10000.1D)) + } + test("non-primitive") { val set = new OpenHashSet[String] assert(set.size === 0) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 725c49b3d55e1..3b0d424390324 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3977,8 +3977,12 @@ object ArrayUnion { array(2) """, since = "2.4.0") -case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike { - override def dataType: DataType = left.dataType +case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { + override def dataType: DataType = { + dataTypeCheck + left.dataType + } var hsInt: OpenHashSet[Int] = _ var hsLong: OpenHashSet[Long] = _ @@ -4057,76 +4061,37 @@ case class 
ArrayExcept(left: Expression, right: Expression) extends ArraySetLike @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = { if (elementTypeSupportEquals) { - elementType match { - case IntegerType => - (array1, array2) => - // avoid boxing of primitive int array elements - // calculate result array size - hsInt = new OpenHashSet[Int] - val elements = evalIntLongPrimitiveType(array1, array2, null, false) - // allocate result array - hsInt = new OpenHashSet[Int] - val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( - IntegerType.defaultSize, elements)) { - new GenericArrayData(new Array[Any](elements)) - } else { - UnsafeArrayData.forPrimitiveArray( - Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) - } - // assign elements into the result array - evalIntLongPrimitiveType(array1, array2, resultArray, false) - resultArray - case LongType => - (array1, array2) => - // avoid boxing of primitive long array elements - // calculate result array size - hsLong = new OpenHashSet[Long] - val elements = evalIntLongPrimitiveType(array1, array2, null, true) - // allocate result array - hsLong = new OpenHashSet[Long] - val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( - LongType.defaultSize, elements)) { - new GenericArrayData(new Array[Any](elements)) - } else { - UnsafeArrayData.forPrimitiveArray( - Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) - } - // assign elements into the result array - evalIntLongPrimitiveType(array1, array2, resultArray, true) - resultArray - case _ => - (array1, array2) => - val hs = new OpenHashSet[Any] - var notFoundNullElement = true - var i = 0 - while (i < array2.numElements()) { - if (array2.isNullAt(i)) { - notFoundNullElement = false - } else { - val elem = array2.get(i, elementType) - hs.add(elem) - } - i += 1 + (array1, array2) => + val hs = new OpenHashSet[Any] + var notFoundNullElement = true + var i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + notFoundNullElement = false + } else { + val elem = array2.get(i, elementType) + hs.add(elem) + } + i += 1 + } + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (notFoundNullElement) { + arrayBuffer += null + notFoundNullElement = false } - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - i = 0 - while (i < array1.numElements()) { - if (array1.isNullAt(i)) { - if (notFoundNullElement) { - arrayBuffer += null - notFoundNullElement = false - } - } else { - val elem = array1.get(i, elementType) - if (!hs.contains(elem)) { - arrayBuffer += elem - hs.add(elem) - } - } - i += 1 + } else { + val elem = array1.get(i, elementType) + if (!hs.contains(elem)) { + arrayBuffer += elem + hs.add(elem) } - new GenericArrayData(arrayBuffer) - } + } + i += 1 + } + new GenericArrayData(arrayBuffer) } else { (array1, array2) => val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] @@ -4184,9 +4149,127 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayData = classOf[ArrayData].getName - val expr = ctx.addReferenceObj("arrayExceptExpr", this) + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val value = ctx.freshName("value") + val hsValue = ctx.freshName("hsValue") + val size = ctx.freshName("size") + val (postFix, openHashElementType, hsJavaTypeName, genHsValue, + getter, setter, javaTypeName, 
arrayBuilder) = + if (elementTypeSupportEquals) { + elementType match { + case BooleanType | ByteType | ShortType | IntegerType => + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + ("$mcI$sp", "Int", "int", + if (elementType != BooleanType) { + s"(int) $value" + } else { + s"$value ? 1 : 0;" + }, + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case LongType | FloatType | DoubleType => + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + val signature = elementType match { + case LongType => "$mcJ$sp" + case FloatType => "$mcF$sp" + case DoubleType => "$mcD$sp" + } + (signature, CodeGenerator.boxedType(elementType), CodeGenerator.javaType(elementType), value, + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case _ => + val genericArrayData = classOf[GenericArrayData].getName + val et = ctx.addReferenceObj("elementType", elementType) + ("", "Object", "Object", value, + s"get($i, $et)", s"update($pos, $value)", "Object", + s"${ev.value} = new $genericArrayData(new Object[$size]);") + } + } else { + ("", "", "", "", "", "", "", "") + } + nullSafeCodeGen(ctx, ev, (array1, array2) => { - s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);" + if (openHashElementType != "") { + // Here, we ensure elementTypeSupportEquals is true + val notFoundNullElement = ctx.freshName("notFoundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" + val hs = ctx.freshName("hs") + val arrayData = classOf[ArrayData].getName + val arrays = ctx.freshName("arrays") + val array = ctx.freshName("array") + val arrayDataIdx = ctx.freshName("arrayDataIdx") + s""" + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |boolean $notFoundNullElement = true; + |int $size = 0; + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | if ($array2.isNullAt($i)) { + | $notFoundNullElement = false; + | } else { + | $javaTypeName $value = $array2.$getter; + | $hsJavaTypeName $hsValue = $genHsValue; + | $hs.add$postFix($hsValue); + | } + |} + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | if ($array1.isNullAt($i)) { + | if ($notFoundNullElement) { + | $size++; + | $notFoundNullElement = false; + | } + | } else { + | $javaTypeName $value = $array1.$getter; + | $hsJavaTypeName $hsValue = $genHsValue; + | if (!$hs.contains($hsValue)) { + | $hs.add$postFix($hsValue); + | $size++; + | } + | } + |} + |$arrayBuilder + |$hs = new $openHashSet$postFix($classTag); + |$notFoundNullElement = true; + |int $pos = 0; + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | if ($array2.isNullAt($i)) { + | $notFoundNullElement = false; + | } else { + | $javaTypeName $value = $array2.$getter; + | $hsJavaTypeName $hsValue = $genHsValue; + | $hs.add$postFix($hsValue); + | } + |} + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | if ($array1.isNullAt($i)) { + | if ($notFoundNullElement) { + | ${ev.value}.setNullAt($pos++); + | $notFoundNullElement = false; + | } + | } else { + | $javaTypeName $value = $array1.$getter; + | $hsJavaTypeName $hsValue = 
$genHsValue; + | if (!$hs.contains($hsValue)) { + | $hs.add$postFix($hsValue); + | ${ev.value}.$setter; + | $pos++; + | } + | } + |} + """.stripMargin + } else { + val expr = ctx.addReferenceObj("arrayExceptExpr", this) + s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);" + } }) } From 447d274121806e6c4098b9e2762452cfc3ccd838 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 27 Jul 2018 16:25:43 +0100 Subject: [PATCH 17/29] add a missing file --- .../expressions/CollectionExpressionsSuite.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 29852224ca3ba..2f6f9064f9e62 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1512,6 +1512,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a04 = Literal.create(Seq(1, 2, null, 4, 5, 1), ArrayType(IntegerType, true)) val a05 = Literal.create(Seq(-5, 4, null, 2, -1), ArrayType(IntegerType, true)) val a06 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, false)) + val abl0 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false)) + val abl1 = Literal.create(Seq[Boolean](false, false), ArrayType(BooleanType, false)) + val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, false)) + val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, false)) + val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, false)) + val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, false)) + val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false)) + val af1 = Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), ArrayType(FloatType, false)) + val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false)) + val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), ArrayType(DoubleType, false)) val a10 = Literal.create(Seq(1L, 2L, 4L, 3L), ArrayType(LongType, false)) val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) @@ -1540,6 +1550,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayExcept(a04, a05), Seq(1, 5)) checkEvaluation(ArrayExcept(a04, a06), Seq(1, 2, null, 4, 5)) checkEvaluation(ArrayExcept(a06, a04), Seq.empty) + checkEvaluation(ArrayExcept(abl0, abl1), Seq[Boolean](true)) + checkEvaluation(ArrayExcept(ab0, ab1), Seq[Byte](1, 3)) + checkEvaluation(ArrayExcept(as0, as1), Seq[Short](1, 3)) + checkEvaluation(ArrayExcept(af0, af1), Seq[Float](1.1F, 3.3F)) + checkEvaluation(ArrayExcept(ad0, ad1), Seq[Double](1.1, 3.3)) checkEvaluation(ArrayExcept(a10, a11), Seq(1L, 3L)) checkEvaluation(ArrayExcept(a12, a11), Seq(1L)) From 0b0a2ae3f9d6809e9f269a94182b27361bfc4563 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 27 Jul 2018 16:59:23 +0100 Subject: [PATCH 18/29] rebase with master --- .../sql/catalyst/expressions/collectionOperations.scala | 4 +++- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 5 ++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3b0d424390324..e0d164d5dac11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -69,6 +69,7 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression } } + /** * Given an array or map, returns total number of elements in it. */ @@ -953,6 +954,7 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { override def prettyName: String = "map_from_entries" } + /** * Common base class for [[SortArray]] and [[ArraySort]]. */ @@ -2370,7 +2372,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio ByteArray.concat(inputs: _*) case StringType => val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs: _*) + UTF8String.concat(inputs : _*) case ArrayType(elementType, _) => val inputs = children.toStream.map(_.eval(input)) if (inputs.contains(null)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0ea444ca14709..e550b142c738d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1173,8 +1173,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("concat function - arrays") { - val nseqi: Seq[Int] = null - val nseqs: Seq[String] = null + val nseqi : Seq[Int] = null + val nseqs : Seq[String] = null val df = Seq( (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) @@ -1513,7 +1513,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } -<<<<<<< HEAD // Shuffle expressions should produce same results at retries in the same DataFrame. 
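  // A usage sketch (hypothetical example; `shuffle` is the built-in collection
  // function whose random seed is fixed when the plan is analyzed, so collecting
  // the same DataFrame again must return the same rows):
  //
  //   val df = Seq((1, 2, 3)).toDF("a", "b", "c").select(shuffle(array("a", "b", "c")))
  //   checkShuffleResult(df)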
private def checkShuffleResult(df: DataFrame): Unit = { checkAnswer(df, df.collect()) From d8e5d798498be7dd976835499004b539cde09005 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 27 Jul 2018 17:10:08 +0100 Subject: [PATCH 19/29] fix scala style error --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e0d164d5dac11..c75f27858083c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4182,7 +4182,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike case FloatType => "$mcF$sp" case DoubleType => "$mcD$sp" } - (signature, CodeGenerator.boxedType(elementType), CodeGenerator.javaType(elementType), value, + (signature, CodeGenerator.boxedType(elementType), + CodeGenerator.javaType(elementType), value, s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), s""" |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} From 4a70e0e1d7ba4bf63d73a025eee262842f9bcaf9 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 27 Jul 2018 18:10:20 +0100 Subject: [PATCH 20/29] fix compilation error --- .../apache/spark/sql/catalyst/expressions/Expression.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index dcb9c96ca3b2d..773aefc0ac1f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -709,7 +709,7 @@ trait ComplexTypeMergingExpression extends Expression { @transient lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType) - override def dataType: DataType = { + def dataTypeCheck: Unit = { require( inputTypesForMerging.nonEmpty, "The collection of input data types must not be empty.") @@ -717,6 +717,10 @@ trait ComplexTypeMergingExpression extends Expression { TypeCoercion.haveSameType(inputTypesForMerging), "All input types must be the same except nullable, containsNull, valueContainsNull flags." 
+ s" The input types found are\n\t${inputTypesForMerging.mkString("\n\t")}") + } + + override def dataType: DataType = { + dataTypeCheck inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) } } From 548f336f9cd84cbc5da276335e2c335d8ba64ab7 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 27 Jul 2018 18:54:25 +0100 Subject: [PATCH 21/29] drop unused code --- .../expressions/collectionOperations.scala | 75 ------------------- 1 file changed, 75 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c75f27858083c..504eb29156533 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3986,81 +3986,6 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike left.dataType } - var hsInt: OpenHashSet[Int] = _ - var hsLong: OpenHashSet[Long] = _ - - def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { - val elem = array.getInt(idx) - if (!hsInt.contains(elem)) { - if (resultArray != null) { - resultArray.setInt(pos, elem) - } - hsInt.add(elem) - true - } else { - false - } - } - - def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { - val elem = array.getLong(idx) - if (!hsLong.contains(elem)) { - if (resultArray != null) { - resultArray.setLong(pos, elem) - } - hsLong.add(elem) - true - } else { - false - } - } - - def evalIntLongPrimitiveType( - array1: ArrayData, - array2: ArrayData, - resultArray: ArrayData, - isLongType: Boolean): Int = { - // store elements into resultArray - var notFoundNullElement = true - var i = 0 - while (i < array2.numElements()) { - if (array2.isNullAt(i)) { - notFoundNullElement = false - } else { - val assigned = if (!isLongType) { - hsInt.add(array2.getInt(i)) - } else { - hsLong.add(array2.getLong(i)) - } - } - i += 1 - } - var pos = 0 - i = 0 - while (i < array1.numElements()) { - if (array1.isNullAt(i)) { - if (notFoundNullElement) { - if (resultArray != null) { - resultArray.setNullAt(pos) - } - pos += 1 - notFoundNullElement = false - } - } else { - val assigned = if (!isLongType) { - assignInt(array1, i, resultArray, pos) - } else { - assignLong(array1, i, resultArray, pos) - } - if (assigned) { - pos += 1 - } - } - i += 1 - } - pos - } - @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = { if (elementTypeSupportEquals) { (array1, array2) => From f4d4f8c89ce82ed058f6696885363f0e00154eaf Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 27 Jul 2018 19:01:41 +0100 Subject: [PATCH 22/29] update comment --- .../org/apache/spark/util/collection/OpenHashSet.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 12a4eeb4abcc4..dda92aee7bdd9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -28,9 +28,9 @@ import org.apache.spark.annotation.Private * removed. 
* * The underlying implementation uses Scala compiler's specialization to generate optimized - * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet - * while incurring much less memory overhead. This can serve as building blocks for higher level - * data structures such as an optimized HashMap. + * storage for four primitive types (Long, Int, Double, and Float). It is much faster than Java's + * standard HashSet while incurring much less memory overhead. This can serve as building blocks + * for higher level data structures such as an optimized HashMap. * * This OpenHashSet is designed to serve as building blocks for higher level data structures * such as an optimized hash map. Compared with standard hash set implementations, this class From 6ef1f22aa68d52cf0c00b21211e19d3f80bab7c6 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 27 Jul 2018 20:32:07 +0100 Subject: [PATCH 23/29] update hasher --- .../spark/util/collection/OpenHashSet.scala | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index dda92aee7bdd9..8883e17bf3164 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -77,6 +77,10 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( (new LongHasher).asInstanceOf[Hasher[T]] } else if (mt == ClassTag.Int) { (new IntHasher).asInstanceOf[Hasher[T]] + } else if (mt == ClassTag.Double) { + (new DoubleHasher).asInstanceOf[Hasher[T]] + } else if (mt == ClassTag.Float) { + (new FloatHasher).asInstanceOf[Hasher[T]] } else { new Hasher[T] } @@ -293,7 +297,7 @@ object OpenHashSet { * A set of specialized hash function implementation to avoid boxing hash code computation * in the specialized implementation of OpenHashSet. 
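 *
 * As a sketch of what the Double and Float hashers added below compute (these
 * mirror java.lang.Double.hashCode and java.lang.Float.floatToIntBits, minus
 * the boxing):
 * {{{
 *   val bits = java.lang.Double.doubleToLongBits(1.5D)
 *   val doubleHash = (bits ^ (bits >>> 32)).toInt
 *   val floatHash = java.lang.Float.floatToIntBits(1.5F)
 * }}}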
*/
-  sealed class Hasher[@specialized(Long, Int) T] extends Serializable {
+  sealed class Hasher[@specialized(Long, Int, Double, Float) T] extends Serializable {
     def hash(o: T): Int = o.hashCode()
   }
 
@@ -305,6 +309,17 @@ object OpenHashSet {
     override def hash(o: Int): Int = o
   }
 
+  class DoubleHasher extends Hasher[Double] {
+    override def hash(o: Double): Int = {
+      val bits = java.lang.Double.doubleToLongBits(o)
+      (bits ^ (bits >>> 32)).toInt
+    }
+  }
+
+  class FloatHasher extends Hasher[Float] {
+    override def hash(o: Float): Int = java.lang.Float.floatToIntBits(o)
+  }
+
   private def grow1(newSize: Int) {}
   private def move1(oldPos: Int, newPos: Int) { }
 

From 3639b5b1b460c19b7c3a787d118a0bcf35cb8e23 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sun, 29 Jul 2018 06:24:02 +0100
Subject: [PATCH 24/29] optimize null checks in generated code

---
 .../expressions/collectionOperations.scala | 61 +++++++++++++------
 1 file changed, 43 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
|for (int $i = 0; $i < $array1.numElements(); $i++) { - | if ($array1.isNullAt($i)) { - | if ($notFoundNullElement) { - | ${ev.value}.setNullAt($pos++); - | $notFoundNullElement = false; - | } - | } else { + | $array1NullAssignment + | { | $javaTypeName $value = $array1.$getter; | $hsJavaTypeName $hsValue = $genHsValue; | if (!$hs.contains($hsValue)) { From 4d943c842548914ab151a7a15fc9e0f8743f0caf Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 31 Jul 2018 04:23:35 +0100 Subject: [PATCH 25/29] address review comment --- .../expressions/collectionOperations.scala | 179 ++++++++++-------- 1 file changed, 100 insertions(+), 79 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5b320e19d1176..8af61f122ad98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4082,26 +4082,25 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val hsValue = ctx.freshName("hsValue") val size = ctx.freshName("size") val (postFix, openHashElementType, hsJavaTypeName, genHsValue, - getter, setter, javaTypeName, arrayBuilder) = + getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) = if (elementTypeSupportEquals) { + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") elementType match { case BooleanType | ByteType | ShortType | IntegerType => - val ptName = CodeGenerator.primitiveTypeName(elementType) - val unsafeArray = ctx.freshName("unsafeArray") ("$mcI$sp", "Int", "int", if (elementType != BooleanType) { s"(int) $value" } else { s"$value ? 
1 : 0;" }, - s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + s"get$ptName($i)", s"set$ptName($pos, $value)", + CodeGenerator.javaType(elementType), ptName, s""" |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} |${ev.value} = $unsafeArray; """.stripMargin) case LongType | FloatType | DoubleType => - val ptName = CodeGenerator.primitiveTypeName(elementType) - val unsafeArray = ctx.freshName("unsafeArray") val signature = elementType match { case LongType => "$mcJ$sp" case FloatType => "$mcF$sp" @@ -4109,115 +4108,137 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike } (signature, CodeGenerator.boxedType(elementType), CodeGenerator.javaType(elementType), value, - s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + s"get$ptName($i)", s"set$ptName($pos, $value)", + CodeGenerator.javaType(elementType), ptName, s""" |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} |${ev.value} = $unsafeArray; """.stripMargin) case _ => val genericArrayData = classOf[GenericArrayData].getName - val et = ctx.addReferenceObj("elementType", elementType) ("", "Object", "Object", value, - s"get($i, $et)", s"update($pos, $value)", "Object", + s"get($i, null)", s"update($pos, $value)", "Object", "Ref", s"${ev.value} = new $genericArrayData(new Object[$size]);") } } else { - ("", "", "", "", "", "", "", "") + ("", "", "", "", "", "", "", "", "") } nullSafeCodeGen(ctx, ev, (array1, array2) => { if (openHashElementType != "") { // Here, we ensure elementTypeSupportEquals is true val notFoundNullElement = ctx.freshName("notFoundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") + val array = ctx.freshName("array") val openHashSet = classOf[OpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" val hs = ctx.freshName("hs") - val arrayData = classOf[ArrayData].getName - val arrays = ctx.freshName("arrays") - val array = ctx.freshName("array") - val arrayDataIdx = ctx.freshName("arrayDataIdx") + val genericArrayData = classOf[GenericArrayData].getName + val arrayBuilder = "scala.collection.mutable.ArrayBuilder" + val arrayBuilderClass = s"$arrayBuilder$$of$primitiveTypeName" + val arrayBuilderClassTag = if (primitiveTypeName != "Ref") { + s"scala.reflect.ClassTag$$.MODULE$$.$primitiveTypeName()" + } else { + s"scala.reflect.ClassTag$$.MODULE$$.AnyRef()" + } - val array2NullCheck = if (right.dataType.asInstanceOf[ArrayType].containsNull) { + def withArray2NullCheck(body: String) = + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $notFoundNullElement = false; + |} else { + | $body + |} + """.stripMargin + } else { + body + } + val array2Body = s""" - |if ($array2.isNullAt($i)) { - | $notFoundNullElement = false; - |} else + |$javaTypeName $value = $array2.$getter; + |$hsJavaTypeName $hsValue = $genHsValue; + |$hs.add$postFix($hsValue); """.stripMargin - } else { - "" - } - val array1NullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + + def withArray1NullAssignment(body: String) = + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($notFoundNullElement) { + | $nullElementIndex = $size; + | $notFoundNullElement = false; + | $size++; + | } + |} else { + | $body + |} + """.stripMargin + } else { + body + } + val array1Body = s""" 
- |if ($array1.isNullAt($i)) { - | if ($notFoundNullElement) { - | $size++; - | $notFoundNullElement = false; - | } - |} else + |$javaTypeName $value = $array1.$getter; + |$hsJavaTypeName $hsValue = $genHsValue; + |if (!$hs.contains($hsValue)) { + | $hs.add$postFix($hsValue); + | $builder.$$plus$$eq($value); + | $size++; + |} """.stripMargin - } else { - "" - } - val array1NullAssignment = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + + val nonNullArrayDataBuild = if (postFix != "") { s""" - |if ($array1.isNullAt($i)) { - | if ($notFoundNullElement) { - | ${ev.value}.setNullAt($pos++); - | $notFoundNullElement = false; - | } - |} else + |if (!UnsafeArrayData.shouldUseGenericArrayData(${elementType.defaultSize}, $size)) { + | ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result()); + |} else { + | ${ev.value} = new $genericArrayData($builder.result()); + |} """.stripMargin } else { - "" + s"${ev.value} = new $genericArrayData($builder.result());" } + def buildResultArrayData(nonNullArrayDataBuild: String) = + if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($nullElementIndex < 0) { + | // result has no null element + | $nonNullArrayDataBuild + |} else { + | // result has null element + | $arrayDataBuilder + | $javaTypeName[] $array = $builder.result(); + | for (int $i = 0, $pos = 0; $pos < $size; $pos++) { + | if ($pos == $nullElementIndex) { + | ${ev.value}.setNullAt($pos); + | } else { + | $javaTypeName $value = $array[$i++]; + | ${ev.value}.$setter; + | } + | } + |} + """.stripMargin + } else { + nonNullArrayDataBuild + } + s""" |$openHashSet $hs = new $openHashSet$postFix($classTag); |boolean $notFoundNullElement = true; - |int $size = 0; |for (int $i = 0; $i < $array2.numElements(); $i++) { - | $array2NullCheck - | { - | $javaTypeName $value = $array2.$getter; - | $hsJavaTypeName $hsValue = $genHsValue; - | $hs.add$postFix($hsValue); - | } - |} - |for (int $i = 0; $i < $array1.numElements(); $i++) { - | $array1NullCheck - | { - | $javaTypeName $value = $array1.$getter; - | $hsJavaTypeName $hsValue = $genHsValue; - | if (!$hs.contains($hsValue)) { - | $hs.add$postFix($hsValue); - | $size++; - | } - | } - |} - |$arrayBuilder - |$hs = new $openHashSet$postFix($classTag); - |$notFoundNullElement = true; - |int $pos = 0; - |for (int $i = 0; $i < $array2.numElements(); $i++) { - | $array2NullCheck - | { - | $javaTypeName $value = $array2.$getter; - | $hsJavaTypeName $hsValue = $genHsValue; - | $hs.add$postFix($hsValue); - | } + | ${withArray2NullCheck(array2Body)} |} + |$arrayBuilderClass $builder = + | ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag); + |int $nullElementIndex = -1; + |int $size = 0; |for (int $i = 0; $i < $array1.numElements(); $i++) { - | $array1NullAssignment - | { - | $javaTypeName $value = $array1.$getter; - | $hsJavaTypeName $hsValue = $genHsValue; - | if (!$hs.contains($hsValue)) { - | $hs.add$postFix($hsValue); - | ${ev.value}.$setter; - | $pos++; - | } - | } + | ${withArray1NullAssignment(array1Body)} |} + |${buildResultArrayData(nonNullArrayDataBuild)} """.stripMargin } else { val expr = ctx.addReferenceObj("arrayExceptExpr", this) From 49b5ab371af9783be8f2d6351cf664a769957a4e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 31 Jul 2018 04:49:01 +0100 Subject: [PATCH 26/29] address review comment --- .../expressions/collectionOperations.scala | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 8af61f122ad98..b89469e2746ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -4081,11 +4081,11 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
     val value = ctx.freshName("value")
     val hsValue = ctx.freshName("hsValue")
     val size = ctx.freshName("size")
-    val (postFix, openHashElementType, hsJavaTypeName, genHsValue,
-      getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) =
-      if (elementTypeSupportEquals) {
-        val ptName = CodeGenerator.primitiveTypeName(elementType)
-        val unsafeArray = ctx.freshName("unsafeArray")
+    if (elementTypeSupportEquals) {
+      val ptName = CodeGenerator.primitiveTypeName(elementType)
+      val unsafeArray = ctx.freshName("unsafeArray")
+      val (postFix, openHashElementType, hsJavaTypeName, genHsValue,
+        getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) =
         elementType match {
           case BooleanType | ByteType | ShortType | IntegerType =>
             ("$mcI$sp", "Int", "int",
@@ -4120,13 +4120,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
             s"get($i, null)", s"update($pos, $value)", "Object", "Ref",
             s"${ev.value} = new $genericArrayData(new Object[$size]);")
         }
-    } else {
-      ("", "", "", "", "", "", "", "", "")
-    }
-
-    nullSafeCodeGen(ctx, ev, (array1, array2) => {
-      if (openHashElementType != "") {
-        // Here, we ensure elementTypeSupportEquals is true
+      nullSafeCodeGen(ctx, ev, (array1, array2) => {
         val notFoundNullElement = ctx.freshName("notFoundNullElement")
         val nullElementIndex = ctx.freshName("nullElementIndex")
         val builder = ctx.freshName("builder")
@@ -4184,7 +4179,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
          |$hsJavaTypeName $hsValue = $genHsValue;
          |if (!$hs.contains($hsValue)) {
          |  $hs.add$postFix($hsValue);
-         |  $builder.$$plus$$eq($value); 
+         |  $builder.$$plus$$eq($value);
          |  $size++;
          |}
        """.stripMargin
@@ -4240,11 +4235,13 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
          |}
          |${buildResultArrayData(nonNullArrayDataBuild)}
        """.stripMargin
-      } else {
+      })
+    } else {
+      nullSafeCodeGen(ctx, ev, (array1, array2) => {
         val expr = ctx.addReferenceObj("arrayExceptExpr", this)
         s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);"
-      }
-    })
+      })
+    }
   }
 
   override def prettyName: String = "array_except"

From 282445cb20fda8e963ceb23d5f52d88c7f8d7d84 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Tue, 31 Jul 2018 08:03:05 +0100
Subject: [PATCH 27/29] address review comment

---
 .../sql/catalyst/expressions/collectionOperations.scala | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index b89469e2746ed..997e9234d8fa8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -4087,13 +4087,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
     val
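      // (Editorial sketch) When the element type lacks reliable equals semantics,
      // the else branch above skips specialized codegen entirely: the expression
      // object is registered as a codegen reference, and the generated Java
      // collapses to roughly this single hypothetical line:
      //
      //   value = (ArrayData) ((ArrayExcept) references[0]).nullSafeEval(array1, array2);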
(postFix, openHashElementType, hsJavaTypeName, genHsValue, getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) = elementType match { - case BooleanType | ByteType | ShortType | IntegerType => - ("$mcI$sp", "Int", "int", - if (elementType != BooleanType) { - s"(int) $value" - } else { - s"$value ? 1 : 0;" - }, + case ByteType | ShortType | IntegerType => + ("$mcI$sp", "Int", "int", s"(int) $value", s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), ptName, s""" From 72ef66479842df6fe793fad39dfcb0f7e494b0fd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 31 Jul 2018 15:16:16 +0100 Subject: [PATCH 28/29] fix pyspark test failure --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 997e9234d8fa8..d8d04ed148055 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4111,8 +4111,9 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike """.stripMargin) case _ => val genericArrayData = classOf[GenericArrayData].getName + val et = ctx.addReferenceObj("elementType", elementType) ("", "Object", "Object", value, - s"get($i, null)", s"update($pos, $value)", "Object", "Ref", + s"get($i, $et)", s"update($pos, $value)", "Object", "Ref", s"${ev.value} = new $genericArrayData(new Object[$size]);") } From 93e7979a1c3fb82c47ecae5b3ed539b31cb99e19 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 31 Jul 2018 19:21:38 +0100 Subject: [PATCH 29/29] address array overflow --- .../expressions/collectionOperations.scala | 32 +++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d8d04ed148055..b03bd7d942d72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4174,22 +4174,36 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike |$javaTypeName $value = $array1.$getter; |$hsJavaTypeName $hsValue = $genHsValue; |if (!$hs.contains($hsValue)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } | $hs.add$postFix($hsValue); - | $builder.$$plus$$eq($value); - | $size++; + | $builder.$$plus$$eq($value); |} """.stripMargin - val nonNullArrayDataBuild = if (postFix != "") { + val nonNullArrayDataBuild = { + val build = if (postFix != "") { + val defaultSize = elementType.defaultSize + s""" + |if (!UnsafeArrayData.shouldUseGenericArrayData($defaultSize, $size)) { + | ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result()); + |} else { + | ${ev.value} = new $genericArrayData($builder.result()); + |} + """.stripMargin + } else { + s"${ev.value} = new $genericArrayData($builder.result());" + } s""" - |if (!UnsafeArrayData.shouldUseGenericArrayData(${elementType.defaultSize}, $size)) { - | ${ev.value} = 
UnsafeArrayData.fromPrimitiveArray($builder.result());
-        |} else {
-        |  ${ev.value} = new $genericArrayData($builder.result());
+        |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+        |  throw new RuntimeException("Cannot create array with " + $size +
+        |    " elements due to exceeding the limit of " +
+        |    "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for GenericArrayData." +
+        |    " $prettyName failed.");
         |}
+        |$build
       """.stripMargin
-      } else {
-        s"${ev.value} = new $genericArrayData($builder.result());"
       }

       def buildResultArrayData(nonNullArrayDataBuild: String) =
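// A standalone model of the overflow guard this patch generates (editorial
// sketch only; `checkedSize` is a hypothetical helper, and
// MAX_ROUNDED_ARRAY_LENGTH is the largest array length Spark will request,
// roughly Int.MaxValue minus alignment slack):
//
//   import org.apache.spark.unsafe.array.ByteArrayMethods
//
//   def checkedSize(size: Int, prettyName: String): Int = {
//     if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
//       throw new RuntimeException(s"Cannot create array with $size elements " +
//         s"due to exceeding the limit of " +
//         s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for " +
//         s"GenericArrayData. $prettyName failed.")
//     }
//     size
//   }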