From 755d6db0cbec1b822e019c425aca258eb7b512a1 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Mon, 26 Mar 2018 10:40:51 +0200 Subject: [PATCH 1/8] [SPARK-23821][SQL] Collection function: flatten --- python/pyspark/sql/functions.py | 17 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 164 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 37 ++++ .../org/apache/spark/sql/functions.scala | 8 + .../spark/sql/DataFrameFunctionsSuite.scala | 45 +++++ 6 files changed, 272 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ad3e37c87262..596cb3f33f7f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2085,6 +2085,23 @@ def sort_array(col, asc=True): return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(2.4) +def flatten(col): + """ + Collection function: creates a single array from an array of arrays. + If a structure of nested arrays is deeper than two levels, + only one level of nesting is removed. 
+ + :param col: name of column or expression + + >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],),([None, [4, 5]],)], ['data']) + >>> df.select(flatten(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.flatten(_to_java_column(col))) + + @since(2.3) def map_keys(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 747016beb06e..0cdcb19347d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -408,6 +408,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[Flatten]("flatten"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index beb84694c44e..9be22a0f4246 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods /** * Given an array or map, returns its size. Returns -1 if null. 
@@ -287,3 +289,165 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Transforms an array of arrays into a single array. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.", + examples = """ + Examples: + > SELECT _FUNC_(array(array(1, 2), array(3, 4)); + [1,2,3,4] + """) +case class Flatten(child: Expression) extends UnaryExpression { + + override def nullable: Boolean = child.nullable || dataType.containsNull + + override def dataType: ArrayType = { + child + .dataType.asInstanceOf[ArrayType] + .elementType.asInstanceOf[ArrayType] + } + + override def checkInputDataTypes(): TypeCheckResult = { + if ( + ArrayType.acceptsType(child.dataType) && + ArrayType.acceptsType(child.dataType.asInstanceOf[ArrayType].elementType) + ) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"The argument should be an array of arrays, " + + s"but '${child.sql}' is of ${child.dataType.simpleString} type." 
+ ) + } + } + + override def nullSafeEval(array: Any): Any = { + val elements = array.asInstanceOf[ArrayData].toObjectArray(dataType) + + if (elements.contains(null)) { + null + } else { + val flattened = elements.flatMap( + _.asInstanceOf[ArrayData].toObjectArray(dataType.elementType) + ) + new GenericArrayData(flattened) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val code = + if (CodeGenerator.isPrimitiveType(dataType.elementType)) { + genCodeForConcatOfPrimitiveElements(ctx, c, ev.value) + } else { + genCodeForConcatOfComplexElements(ctx, c, ev.value) + } + nullElementsProtection(ev, c, code) + }) + } + + private def nullElementsProtection( + ev: ExprCode, + childVariableName: String, + coreLogic: String): String = { + s""" + |for(int z=0; z < $childVariableName.numElements(); z++) { + | ${ev.isNull} |= $childVariableName.isNullAt(z); + |} + |if(!${ev.isNull}) { + | $coreLogic + |} + """.stripMargin + } + + private def genCodeForNumberOfElements( + ctx: CodegenContext, + childVariableName: String) : (String, String) = { + val variableName = ctx.freshName("numElements") + val code = + s""" + |int $variableName = 0; + |for(int z=0; z < $childVariableName.numElements(); z++) { + | $variableName += $childVariableName.getArray(z).numElements(); + |} + """.stripMargin + (code, variableName) + } + + private def genCodeForConcatOfPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val arrayName = ctx.freshName("array") + val arraySizeName = ctx.freshName("size") + val counter = ctx.freshName("counter") + val tempArrayDataName = ctx.freshName("tempArrayData") + + val elementType = dataType.elementType + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + val unsafeArraySizeInBytes = s""" + |int $arraySizeName = UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) + + 
|${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord( + | ${elementType.defaultSize} * $numElemName + |); + """.stripMargin + val baseOffset = Platform.BYTE_ARRAY_OFFSET + + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |$numElemCode + |$unsafeArraySizeInBytes + |byte[] $arrayName = new byte[$arraySizeName]; + |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); + |Platform.putLong($arrayName, $baseOffset, $numElemName); + |$tempArrayDataName.pointTo($arrayName, $baseOffset, $arraySizeName); + |int $counter = 0; + |for(int k=0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for(int l = 0; l < arr.numElements(); l++) { + | if(arr.isNullAt(l)) { + | $tempArrayDataName.setNullAt($counter); + | } else { + | $tempArrayDataName.set$primitiveValueTypeName( + | $counter, + | arr.get$primitiveValueTypeName(l) + | ); + | } + | $counter++; + | } + |} + |$arrayDataName = $tempArrayDataName; + """.stripMargin + } + + private def genCodeForConcatOfComplexElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val counter = ctx.freshName("counter") + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + s""" + |$numElemCode + |Object[] $arrayName = new Object[$numElemName]; + |int $counter = 0; + |for(int k=0; k < $childVariableName.numElements(); k++) { + | Object[] arr = $childVariableName.getArray(k).array(); + | for(int l = 0; l < arr.length; l++) { + | $arrayName[$counter] = arr[l]; + | $counter++; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + + override def prettyName: String = "flatten" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 020687e4b3a2..4d66c9cde361 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -105,4 +105,41 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Flatten") { + val intType = ArrayType(ArrayType(IntegerType)) + val ai0 = Literal.create(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), intType) + val ai1 = Literal.create(Seq(Seq(1, 2, 3), Seq.empty, Seq(6)), intType) + val ai2 = Literal.create(Seq(Seq(null, null, null), Seq(4, null), Seq(6)), intType) + val ai3 = Literal.create(Seq(null, Seq(4, null), Seq(6)), intType) + val ai4 = Literal.create(Seq(Seq(1)), intType) + val ai5 = Literal.create(Seq(Seq.empty), intType) + val ai6 = Literal.create(Seq.empty, intType) + + checkEvaluation(Flatten(ai0), Seq(1, 2, 3, 4, 5, 6)) + checkEvaluation(Flatten(ai1), Seq(1, 2, 3, 6)) + checkEvaluation(Flatten(ai2), Seq(null, null, null, 4, null, 6)) + checkEvaluation(Flatten(ai3), null) + checkEvaluation(Flatten(ai4), Seq(1)) + checkEvaluation(Flatten(ai5), Seq.empty) + checkEvaluation(Flatten(ai6), Seq.empty) + + val strType = ArrayType(ArrayType(StringType)) + val as0 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strType) + val as1 = Literal.create(Seq(Seq.empty, Seq("a", "b"), Seq.empty), strType) + val as2 = Literal.create(Seq(Seq(null, null), Seq("a", null), Seq(null)), strType) + val as3 = Literal.create(Seq(Seq("a"), null), strType) + val as4 = Literal.create(Seq(Seq("a")), strType) + val as5 = Literal.create(Seq(Seq.empty), strType) + val as6 = Literal.create(Seq.empty, strType) + + checkEvaluation(Flatten(as0), 
Seq("a", "b", "c", "d", "e", "f")) + checkEvaluation(Flatten(as1), Seq("a", "b")) + checkEvaluation(Flatten(as2), Seq(null, null, "a", null, null)) + checkEvaluation(Flatten(as3), null) + checkEvaluation(Flatten(as4), Seq("a")) + checkEvaluation(Flatten(as5), Seq.empty) + checkEvaluation(Flatten(as6), Seq.empty) + + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c9ca9a899634..db9bf5e59663 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3300,6 +3300,14 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than + * two levels, only one level of nesting is removed. + * @group collection_funcs + * @since 2.4.0 + */ + def flatten(e: Column): Column = withExpr{ Flatten(e.expr) } + /** * Returns an unordered array containing the keys of the map. 
* @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 50e475984f45..86acc21296ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,51 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("flatten function") { + val df = Seq( + (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), Seq(Seq("a", "b"), Seq("c"))), + (Seq(Seq(1), Seq.empty, Seq(2)), Seq(Seq(null), Seq(null, "a"))), + (Seq(Seq(2), null, Seq(1)), Seq(Seq("a"), null)) + ).toDF("i", "s") + val edf = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + // Simple test cases + checkAnswer( + df.select(flatten($"i")), + Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)), Row(null)) + ) + checkAnswer( + df.selectExpr("flatten(i)"), + Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)), Row(null)) + ) + checkAnswer( + edf.selectExpr("flatten(array(arr, array(null, 5)))"), + Seq(Row(Seq(1, 2, 3, null, 5))) + ) + checkAnswer( + df.select(flatten($"s")), + Seq(Row(Seq("a", "b", "c")), Row(Seq(null, null, "a")), Row(null)) + ) + checkAnswer( + df.selectExpr("flatten(s)"), + Seq(Row(Seq("a", "b", "c")), Row(Seq(null, null, "a")), Row(null)) + ) + + // Error test cases + intercept[AnalysisException] { + edf.select(flatten($"arr")) + } + intercept[AnalysisException] { + edf.select(flatten($"i")) + } + intercept[AnalysisException] { + edf.select(flatten($"s")) + } + intercept[AnalysisException] { + edf.selectExpr("flatten(null)") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From ad469628d42054f3ccab5622e145c82301825ce7 Mon Sep 17 00:00:00 2001 From: mn-mikke 
Date: Tue, 3 Apr 2018 00:15:58 +0200 Subject: [PATCH 2/8] [SPARK-23821][SQL] Improving test cases --- python/pyspark/sql/functions.py | 2 +- .../CollectionExpressionsSuite.scala | 125 +++++++++++++----- .../spark/sql/DataFrameFunctionsSuite.scala | 97 +++++++++----- 3 files changed, 159 insertions(+), 65 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 596cb3f33f7f..748a961c3e0c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2094,7 +2094,7 @@ def flatten(col): :param col: name of column or expression - >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],),([None, [4, 5]],)], ['data']) + >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) >>> df.select(flatten(df.data).alias('r')).collect() [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)] """ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 4d66c9cde361..4e7b244268e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -107,39 +107,98 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("Flatten") { - val intType = ArrayType(ArrayType(IntegerType)) - val ai0 = Literal.create(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), intType) - val ai1 = Literal.create(Seq(Seq(1, 2, 3), Seq.empty, Seq(6)), intType) - val ai2 = Literal.create(Seq(Seq(null, null, null), Seq(4, null), Seq(6)), intType) - val ai3 = Literal.create(Seq(null, Seq(4, null), Seq(6)), intType) - val ai4 = Literal.create(Seq(Seq(1)), intType) - val ai5 = Literal.create(Seq(Seq.empty), intType) - val ai6 = Literal.create(Seq.empty, intType) - - 
checkEvaluation(Flatten(ai0), Seq(1, 2, 3, 4, 5, 6)) - checkEvaluation(Flatten(ai1), Seq(1, 2, 3, 6)) - checkEvaluation(Flatten(ai2), Seq(null, null, null, 4, null, 6)) - checkEvaluation(Flatten(ai3), null) - checkEvaluation(Flatten(ai4), Seq(1)) - checkEvaluation(Flatten(ai5), Seq.empty) - checkEvaluation(Flatten(ai6), Seq.empty) - - val strType = ArrayType(ArrayType(StringType)) - val as0 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strType) - val as1 = Literal.create(Seq(Seq.empty, Seq("a", "b"), Seq.empty), strType) - val as2 = Literal.create(Seq(Seq(null, null), Seq("a", null), Seq(null)), strType) - val as3 = Literal.create(Seq(Seq("a"), null), strType) - val as4 = Literal.create(Seq(Seq("a")), strType) - val as5 = Literal.create(Seq(Seq.empty), strType) - val as6 = Literal.create(Seq.empty, strType) - - checkEvaluation(Flatten(as0), Seq("a", "b", "c", "d", "e", "f")) - checkEvaluation(Flatten(as1), Seq("a", "b")) - checkEvaluation(Flatten(as2), Seq(null, null, "a", null, null)) - checkEvaluation(Flatten(as3), null) - checkEvaluation(Flatten(as4), Seq("a")) - checkEvaluation(Flatten(as5), Seq.empty) - checkEvaluation(Flatten(as6), Seq.empty) + // Primitive-type test cases + val intArrayType = ArrayType(ArrayType(IntegerType)) + + // Main test cases (primitive type) + val aim1 = Literal.create(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), intArrayType) + val aim2 = Literal.create(Seq(Seq(1, 2, 3)), intArrayType) + + checkEvaluation(Flatten(aim1), Seq(1, 2, 3, 4, 5, 6)) + checkEvaluation(Flatten(aim2), Seq(1, 2, 3)) + + // Test cases with an empty array (primitive type) + val aie1 = Literal.create(Seq(Seq.empty, Seq(1, 2), Seq(3, 4)), intArrayType) + val aie2 = Literal.create(Seq(Seq(1, 2), Seq.empty, Seq(3, 4)), intArrayType) + val aie3 = Literal.create(Seq(Seq(1, 2), Seq(3, 4), Seq.empty), intArrayType) + val aie4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), intArrayType) + val aie5 = Literal.create(Seq(Seq.empty), intArrayType) + val 
aie6 = Literal.create(Seq.empty, intArrayType) + + checkEvaluation(Flatten(aie1), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie2), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie3), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie4), Seq.empty) + checkEvaluation(Flatten(aie5), Seq.empty) + checkEvaluation(Flatten(aie6), Seq.empty) + + // Test cases with null elements (primitive type) + val ain1 = Literal.create(Seq(Seq(null, null, null), Seq(4, null)), intArrayType) + val ain2 = Literal.create(Seq(Seq(null, 2, null), Seq(null, null)), intArrayType) + val ain3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), intArrayType) + + checkEvaluation(Flatten(ain1), Seq(null, null, null, 4, null)) + checkEvaluation(Flatten(ain2), Seq(null, 2, null, null, null)) + checkEvaluation(Flatten(ain3), Seq(null, null, null, null)) + + // Test cases with a null array (primitive type) + val aia1 = Literal.create(Seq(null, Seq(1, 2)), intArrayType) + val aia2 = Literal.create(Seq(Seq(1, 2), null), intArrayType) + val aia3 = Literal.create(Seq(null), intArrayType) + val aia4 = Literal.create(null, intArrayType) + + checkEvaluation(Flatten(aia1), null) + checkEvaluation(Flatten(aia2), null) + checkEvaluation(Flatten(aia3), null) + checkEvaluation(Flatten(aia4), null) + + // Complex-type test cases + val strArrayType = ArrayType(ArrayType(StringType)) + val arrArrayType = ArrayType(ArrayType(ArrayType(StringType))) + + // Main test cases (complex type) + val asm1 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strArrayType) + val asm2 = Literal.create(Seq(Seq("a", "b")), strArrayType) + val asm3 = Literal.create(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e"))), arrArrayType) + + checkEvaluation(Flatten(asm1), Seq("a", "b", "c", "d", "e", "f")) + checkEvaluation(Flatten(asm2), Seq("a", "b")) + checkEvaluation(Flatten(asm3), Seq(Seq("a", "b"), Seq("c"), Seq("d", "e"))) + + // Test cases with an empty array (complex type) + val ase1 = Literal.create(Seq(Seq.empty, 
Seq("a", "b"), Seq("c", "d")), strArrayType) + val ase2 = Literal.create(Seq(Seq("a", "b"), Seq.empty, Seq("c", "d")), strArrayType) + val ase3 = Literal.create(Seq(Seq("a", "b"), Seq("c", "d"), Seq.empty), strArrayType) + val ase4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), strArrayType) + val ase5 = Literal.create(Seq(Seq.empty), strArrayType) + val ase6 = Literal.create(Seq.empty, strArrayType) + + checkEvaluation(Flatten(ase1), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase2), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase3), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase4), Seq.empty) + checkEvaluation(Flatten(ase5), Seq.empty) + checkEvaluation(Flatten(ase6), Seq.empty) + + // Test cases with null elements (complex type) + val asn1 = Literal.create(Seq(Seq(null, null, "c"), Seq(null, null)), strArrayType) + val asn2 = Literal.create(Seq(Seq(null, null, null), Seq("d", null)), strArrayType) + val asn3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), strArrayType) + + checkEvaluation(Flatten(asn1), Seq(null, null, "c", null, null)) + checkEvaluation(Flatten(asn2), Seq(null, null, null, "d", null)) + checkEvaluation(Flatten(asn3), Seq(null, null, null, null)) + + // Test cases with a null array (complex type) + val asa1 = Literal.create(Seq(null, Seq("a", "b")), strArrayType) + val asa2 = Literal.create(Seq(Seq("a", "b"), null), strArrayType) + val asa3 = Literal.create(Seq(null), strArrayType) + val asa4 = Literal.create(null, strArrayType) + + checkEvaluation(Flatten(asa1), null) + checkEvaluation(Flatten(asa2), null) + checkEvaluation(Flatten(asa3), null) + checkEvaluation(Flatten(asa4), null) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 86acc21296ce..dcc2d0f1735f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -414,47 +414,82 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("flatten function") { - val df = Seq( - (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), Seq(Seq("a", "b"), Seq("c"))), - (Seq(Seq(1), Seq.empty, Seq(2)), Seq(Seq(null), Seq(null, "a"))), - (Seq(Seq(2), null, Seq(1)), Seq(Seq("a"), null)) - ).toDF("i", "s") - val edf = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") - - // Simple test cases - checkAnswer( - df.select(flatten($"i")), - Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)), Row(null)) - ) - checkAnswer( - df.selectExpr("flatten(i)"), - Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)), Row(null)) - ) - checkAnswer( - edf.selectExpr("flatten(array(arr, array(null, 5)))"), - Seq(Row(Seq(1, 2, 3, null, 5))) - ) - checkAnswer( - df.select(flatten($"s")), - Seq(Row(Seq("a", "b", "c")), Row(Seq(null, null, "a")), Row(null)) - ) - checkAnswer( - df.selectExpr("flatten(s)"), - Seq(Row(Seq("a", "b", "c")), Row(Seq(null, null, "a")), Row(null)) + val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + // Test cases with a primitive type + val intDF = Seq( + (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), + (Seq(Seq(1, 2))), + (Seq(Seq(1), Seq.empty)), + (Seq(Seq.empty, Seq(1))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq(1), null)), + (Seq(null, Seq(1))), + (Seq(null, null)) + ).toDF("i") + + val intDFResult = Seq( + Row(Seq(1, 2, 3, 4, 5, 6)), + Row(Seq(1, 2)), + Row(Seq(1)), + Row(Seq(1)), + Row(Seq.empty), + Row(null), + Row(null), + Row(null) + ) + + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + checkAnswer( + oneRowDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), + Seq(Row(Seq(1, 2, 3, null, 5, 6, null))) + ) + + // Test cases with complex types + val strDF = Seq( + (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))), + (Seq(Seq("a", "b"))), + (Seq(Seq("a", 
null), Seq(null, "b"), Seq(null, null))), + (Seq(Seq("a"), Seq.empty)), + (Seq(Seq.empty, Seq("a"))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq("a"), null)), + (Seq(null, Seq("a"))), + (Seq(null, null)) + ).toDF("s") + + val strDFResult = Seq( + Row(Seq("a", "b", "c", "d", "e", "f")), + Row(Seq("a", "b")), + Row(Seq("a", null, null, "b", null, null)), + Row(Seq("a")), + Row(Seq("a")), + Row(Seq.empty), + Row(null), + Row(null), + Row(null) + ) + + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + checkAnswer( + oneRowDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3)))) ) // Error test cases intercept[AnalysisException] { - edf.select(flatten($"arr")) + oneRowDF.select(flatten($"arr")) } intercept[AnalysisException] { - edf.select(flatten($"i")) + oneRowDF.select(flatten($"i")) } intercept[AnalysisException] { - edf.select(flatten($"s")) + oneRowDF.select(flatten($"s")) } intercept[AnalysisException] { - edf.selectExpr("flatten(null)") + oneRowDF.selectExpr("flatten(null)") } } From a50d42ebd26da8474200e03e50f2a891eba235f5 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Mon, 9 Apr 2018 13:38:23 +0200 Subject: [PATCH 3/8] [SPARK-23821][SQL] Code-styling improvements --- .../expressions/collectionOperations.scala | 111 +++++++++--------- .../org/apache/spark/sql/functions.scala | 2 +- .../spark/sql/DataFrameFunctionsSuite.scala | 12 +- 3 files changed, 58 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 9be22a0f4246..714c4fcb80af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ 
-299,7 +299,8 @@ case class ArrayContains(left: Expression, right: Expression) Examples: > SELECT _FUNC_(array(array(1, 2), array(3, 4)); [1,2,3,4] - """) + """, + since = "2.4.0") case class Flatten(child: Expression) extends UnaryExpression { override def nullable: Boolean = child.nullable || dataType.containsNull @@ -310,18 +311,14 @@ case class Flatten(child: Expression) extends UnaryExpression { .elementType.asInstanceOf[ArrayType] } - override def checkInputDataTypes(): TypeCheckResult = { - if ( - ArrayType.acceptsType(child.dataType) && - ArrayType.acceptsType(child.dataType.asInstanceOf[ArrayType].elementType) - ) { + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(_: ArrayType, _) => TypeCheckResult.TypeCheckSuccess - } else { + case _ => TypeCheckResult.TypeCheckFailure( s"The argument should be an array of arrays, " + s"but '${child.sql}' is of ${child.dataType.simpleString} type." ) - } } override def nullSafeEval(array: Any): Any = { @@ -339,8 +336,7 @@ case class Flatten(child: Expression) extends UnaryExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { - val code = - if (CodeGenerator.isPrimitiveType(dataType.elementType)) { + val code = if (CodeGenerator.isPrimitiveType(dataType.elementType)) { genCodeForConcatOfPrimitiveElements(ctx, c, ev.value) } else { genCodeForConcatOfComplexElements(ctx, c, ev.value) @@ -354,26 +350,25 @@ case class Flatten(child: Expression) extends UnaryExpression { childVariableName: String, coreLogic: String): String = { s""" - |for(int z=0; z < $childVariableName.numElements(); z++) { - | ${ev.isNull} |= $childVariableName.isNullAt(z); - |} - |if(!${ev.isNull}) { - | $coreLogic - |} - """.stripMargin + |for(int z=0; z < $childVariableName.numElements(); z++) { + | ${ev.isNull} |= $childVariableName.isNullAt(z); + |} + |if(!${ev.isNull}) { + | $coreLogic + |} + """.stripMargin } private def 
genCodeForNumberOfElements( ctx: CodegenContext, childVariableName: String) : (String, String) = { val variableName = ctx.freshName("numElements") - val code = - s""" - |int $variableName = 0; - |for(int z=0; z < $childVariableName.numElements(); z++) { - | $variableName += $childVariableName.getArray(z).numElements(); - |} - """.stripMargin + val code = s""" + |int $variableName = 0; + |for(int z=0; z < $childVariableName.numElements(); z++) { + | $variableName += $childVariableName.getArray(z).numElements(); + |} + """.stripMargin (code, variableName) } @@ -400,28 +395,28 @@ case class Flatten(child: Expression) extends UnaryExpression { val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) s""" - |$numElemCode - |$unsafeArraySizeInBytes - |byte[] $arrayName = new byte[$arraySizeName]; - |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); - |Platform.putLong($arrayName, $baseOffset, $numElemName); - |$tempArrayDataName.pointTo($arrayName, $baseOffset, $arraySizeName); - |int $counter = 0; - |for(int k=0; k < $childVariableName.numElements(); k++) { - | ArrayData arr = $childVariableName.getArray(k); - | for(int l = 0; l < arr.numElements(); l++) { - | if(arr.isNullAt(l)) { - | $tempArrayDataName.setNullAt($counter); - | } else { - | $tempArrayDataName.set$primitiveValueTypeName( - | $counter, - | arr.get$primitiveValueTypeName(l) - | ); - | } - | $counter++; - | } - |} - |$arrayDataName = $tempArrayDataName; + |$numElemCode + |$unsafeArraySizeInBytes + |byte[] $arrayName = new byte[$arraySizeName]; + |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); + |Platform.putLong($arrayName, $baseOffset, $numElemName); + |$tempArrayDataName.pointTo($arrayName, $baseOffset, $arraySizeName); + |int $counter = 0; + |for(int k=0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for(int l = 0; l < arr.numElements(); l++) { + | if(arr.isNullAt(l)) { + | 
$tempArrayDataName.setNullAt($counter); + | } else { + | $tempArrayDataName.set$primitiveValueTypeName( + | $counter, + | arr.get$primitiveValueTypeName(l) + | ); + | } + | $counter++; + | } + |} + |$arrayDataName = $tempArrayDataName; """.stripMargin } @@ -435,18 +430,18 @@ case class Flatten(child: Expression) extends UnaryExpression { val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) s""" - |$numElemCode - |Object[] $arrayName = new Object[$numElemName]; - |int $counter = 0; - |for(int k=0; k < $childVariableName.numElements(); k++) { - | Object[] arr = $childVariableName.getArray(k).array(); - | for(int l = 0; l < arr.length; l++) { - | $arrayName[$counter] = arr[l]; - | $counter++; - | } - |} - |$arrayDataName = new $genericArrayClass($arrayName); - """.stripMargin + |$numElemCode + |Object[] $arrayName = new Object[$numElemName]; + |int $counter = 0; + |for(int k=0; k < $childVariableName.numElements(); k++) { + | Object[] arr = $childVariableName.getArray(k).array(); + | for(int l = 0; l < arr.length; l++) { + | $arrayName[$counter] = arr[l]; + | $counter++; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin } override def prettyName: String = "flatten" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index db9bf5e59663..abc351df8793 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3306,7 +3306,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def flatten(e: Column): Column = withExpr{ Flatten(e.expr) } + def flatten(e: Column): Column = withExpr { Flatten(e.expr) } /** * Returns an unordered array containing the keys of the map. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index dcc2d0f1735f..0ceb032cbd45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -436,15 +436,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq.empty), Row(null), Row(null), - Row(null) - ) + Row(null)) checkAnswer(intDF.select(flatten($"i")), intDFResult) checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) checkAnswer( oneRowDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), - Seq(Row(Seq(1, 2, 3, null, 5, 6, null))) - ) + Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) // Test cases with complex types val strDF = Seq( @@ -468,15 +466,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq.empty), Row(null), Row(null), - Row(null) - ) + Row(null)) checkAnswer(strDF.select(flatten($"s")), strDFResult) checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) checkAnswer( oneRowDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), - Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3)))) - ) + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) // Error test cases intercept[AnalysisException] { From 0089e456eb1014c2cf0244df94e8185af19fd66d Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Mon, 16 Apr 2018 13:59:12 +0200 Subject: [PATCH 4/8] [SPARK-23821][SQL] Checks of max array size + Added more tests --- .../spark/unsafe/array/ByteArrayMethods.java | 6 +- .../catalyst/expressions/UnsafeArrayData.java | 12 ++- .../expressions/collectionOperations.scala | 83 ++++++++++++------- .../CollectionExpressionsSuite.scala | 10 +-- .../spark/sql/DataFrameFunctionsSuite.scala | 5 +- 5 files changed, 76 insertions(+), 40 deletions(-) diff --git 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 4bc9955090fd..ef0f78d95d1e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -33,7 +33,11 @@ public static long nextPowerOf2(long num) { } public static int roundNumberOfBytesToNearestWord(int numBytes) { - int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` + return (int)roundNumberOfBytesToNearestWord((long)numBytes); + } + + public static long roundNumberOfBytesToNearestWord(long numBytes) { + long remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` if (remainder == 0) { return numBytes; } else { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 8546c2833553..eb1b541a13dd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -55,10 +55,20 @@ public final class UnsafeArrayData extends ArrayData { - public static int calculateHeaderPortionInBytes(int numFields) { + public static int calculateHeaderPortionInBytes(int numElements) { + return (int)calculateHeaderPortionInBytes((long)numElements); + } + + public static long calculateHeaderPortionInBytes(long numFields) { return 8 + ((numFields + 63)/ 64) * 8; } + public static long calculateSizeOfUnderlyingByteArray(long numFields, int elementSize) { + long size = UnsafeArrayData.calculateHeaderPortionInBytes(numFields) + + ByteArrayMethods.roundNumberOfBytesToNearestWord(numFields * elementSize); + return size; + } + private Object baseObject; private long baseOffset; diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0f81e60586b4..1426280b96bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -303,13 +303,15 @@ case class ArrayContains(left: Expression, right: Expression) since = "2.4.0") case class Flatten(child: Expression) extends UnaryExpression { - override def nullable: Boolean = child.nullable || dataType.containsNull + private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - override def dataType: ArrayType = { - child - .dataType.asInstanceOf[ArrayType] - .elementType.asInstanceOf[ArrayType] - } + private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] + + override def nullable: Boolean = child.nullable || childDataType.containsNull + + override def dataType: DataType = childDataType.elementType + + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def checkInputDataTypes(): TypeCheckResult = child.dataType match { case ArrayType(_: ArrayType, _) => @@ -321,26 +323,37 @@ case class Flatten(child: Expression) extends UnaryExpression { ) } - override def nullSafeEval(array: Any): Any = { - val elements = array.asInstanceOf[ArrayData].toObjectArray(dataType) + override def nullSafeEval(child: Any): Any = { + val elements = child.asInstanceOf[ArrayData].toObjectArray(dataType) if (elements.contains(null)) { null } else { - val flattened = elements.flatMap( - _.asInstanceOf[ArrayData].toObjectArray(dataType.elementType) + val arrays = elements.map( + _.asInstanceOf[ArrayData].toObjectArray(elementType) ) - new GenericArrayData(flattened) + val numberOfElements = arrays.foldLeft(0L)((sum, e) => sum + e.length) + if(numberOfElements 
> MAX_ARRAY_LENGTH) { + throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + s" $numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + val flattenedData = new Array(numberOfElements.toInt) + var position = 0 + for(a <- arrays) { + Array.copy(a, 0, flattenedData, position, a.length) + position += a.length + } + new GenericArrayData(flattenedData) } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { - val code = if (CodeGenerator.isPrimitiveType(dataType.elementType)) { - genCodeForConcatOfPrimitiveElements(ctx, c, ev.value) - } else { - genCodeForConcatOfComplexElements(ctx, c, ev.value) - } + val code = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value) + } else { + genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) + } nullElementsProtection(ev, c, code) }) } @@ -350,7 +363,7 @@ case class Flatten(child: Expression) extends UnaryExpression { childVariableName: String, coreLogic: String): String = { s""" - |for(int z=0; z < $childVariableName.numElements(); z++) { + |for(int z=0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { | ${ev.isNull} |= $childVariableName.isNullAt(z); |} |if(!${ev.isNull}) { @@ -364,15 +377,19 @@ case class Flatten(child: Expression) extends UnaryExpression { childVariableName: String) : (String, String) = { val variableName = ctx.freshName("numElements") val code = s""" - |int $variableName = 0; + |long $variableName = 0; |for(int z=0; z < $childVariableName.numElements(); z++) { | $variableName += $childVariableName.getArray(z).numElements(); |} + |if ($variableName > ${MAX_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with" + + | " $variableName elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} """.stripMargin (code, variableName) } - private def 
genCodeForConcatOfPrimitiveElements( + private def genCodeForFlattenOfPrimitiveElements( ctx: CodegenContext, childVariableName: String, arrayDataName: String): String = { @@ -381,14 +398,16 @@ case class Flatten(child: Expression) extends UnaryExpression { val counter = ctx.freshName("counter") val tempArrayDataName = ctx.freshName("tempArrayData") - val elementType = dataType.elementType val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) val unsafeArraySizeInBytes = s""" - |int $arraySizeName = UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) + - |${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord( - | ${elementType.defaultSize} * $numElemName - |); + |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElemName, + | ${elementType.defaultSize}); + |if ($arraySizeName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with" + + | " $arraySizeName bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes."); + |} """.stripMargin val baseOffset = Platform.BYTE_ARRAY_OFFSET @@ -397,10 +416,10 @@ case class Flatten(child: Expression) extends UnaryExpression { s""" |$numElemCode |$unsafeArraySizeInBytes - |byte[] $arrayName = new byte[$arraySizeName]; + |byte[] $arrayName = new byte[(int)$arraySizeName]; |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); |Platform.putLong($arrayName, $baseOffset, $numElemName); - |$tempArrayDataName.pointTo($arrayName, $baseOffset, $arraySizeName); + |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName); |int $counter = 0; |for(int k=0; k < $childVariableName.numElements(); k++) { | ArrayData arr = $childVariableName.getArray(k); @@ -410,7 +429,7 @@ case class Flatten(child: Expression) extends UnaryExpression { | } else { | $tempArrayDataName.set$primitiveValueTypeName( | $counter, - | arr.get$primitiveValueTypeName(l) + | ${CodeGenerator.getValue("arr", 
elementType, "l")} | ); | } | $counter++; @@ -420,7 +439,7 @@ case class Flatten(child: Expression) extends UnaryExpression { """.stripMargin } - private def genCodeForConcatOfComplexElements( + private def genCodeForFlattenOfNonPrimitiveElements( ctx: CodegenContext, childVariableName: String, arrayDataName: String): String = { @@ -431,12 +450,12 @@ case class Flatten(child: Expression) extends UnaryExpression { s""" |$numElemCode - |Object[] $arrayName = new Object[$numElemName]; + |Object[] $arrayName = new Object[(int)$numElemName]; |int $counter = 0; |for(int k=0; k < $childVariableName.numElements(); k++) { - | Object[] arr = $childVariableName.getArray(k).array(); - | for(int l = 0; l < arr.length; l++) { - | $arrayName[$counter] = arr[l]; + | ArrayData arr = $childVariableName.getArray(k); + | for(int l = 0; l < arr.numElements(); l++) { + | $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")}; | $counter++; | } |} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 4e7b244268e1..09ecf161a4ff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -152,11 +152,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Flatten(aia3), null) checkEvaluation(Flatten(aia4), null) - // Complex-type test cases + // Non-primitive-type test cases val strArrayType = ArrayType(ArrayType(StringType)) val arrArrayType = ArrayType(ArrayType(ArrayType(StringType))) - // Main test cases (complex type) + // Main test cases (non-primitive type) val asm1 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strArrayType) val asm2 = Literal.create(Seq(Seq("a", 
"b")), strArrayType) val asm3 = Literal.create(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e"))), arrArrayType) @@ -165,7 +165,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Flatten(asm2), Seq("a", "b")) checkEvaluation(Flatten(asm3), Seq(Seq("a", "b"), Seq("c"), Seq("d", "e"))) - // Test cases with an empty array (complex type) + // Test cases with an empty array (non-primitive type) val ase1 = Literal.create(Seq(Seq.empty, Seq("a", "b"), Seq("c", "d")), strArrayType) val ase2 = Literal.create(Seq(Seq("a", "b"), Seq.empty, Seq("c", "d")), strArrayType) val ase3 = Literal.create(Seq(Seq("a", "b"), Seq("c", "d"), Seq.empty), strArrayType) @@ -180,7 +180,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Flatten(ase5), Seq.empty) checkEvaluation(Flatten(ase6), Seq.empty) - // Test cases with null elements (complex type) + // Test cases with null elements (non-primitive type) val asn1 = Literal.create(Seq(Seq(null, null, "c"), Seq(null, null)), strArrayType) val asn2 = Literal.create(Seq(Seq(null, null, null), Seq("d", null)), strArrayType) val asn3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), strArrayType) @@ -189,7 +189,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Flatten(asn2), Seq(null, null, null, "d", null)) checkEvaluation(Flatten(asn3), Seq(null, null, null, null)) - // Test cases with a null array (complex type) + // Test cases with a null array (non-primitive type) val asa1 = Literal.create(Seq(null, Seq("a", "b")), strArrayType) val asa2 = Literal.create(Seq(Seq("a", "b"), null), strArrayType) val asa3 = Literal.create(Seq(null), strArrayType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0ceb032cbd45..d50400ba9fdd 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -414,6 +414,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("flatten function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") // Test cases with a primitive type @@ -439,12 +440,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null)) checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.filter(dummyFilter($"i")).select(flatten($"i")), intDFResult) checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) checkAnswer( oneRowDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) - // Test cases with complex types + // Test cases with non-primitive types val strDF = Seq( (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))), (Seq(Seq("a", "b"))), @@ -469,6 +471,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null)) checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.filter(dummyFilter($"s")).select(flatten($"s")), strDFResult) checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) checkAnswer( oneRowDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), From 2ceb53b54b14eeff185d74f8283639f6825cbf44 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Mon, 16 Apr 2018 18:37:00 +0200 Subject: [PATCH 5/8] [SPARK-23821][SQL] Optimizing evaluation without codegen. 
--- .../expressions/collectionOperations.scala | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index fb102239b472..373bdd377d81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -394,19 +394,18 @@ case class Flatten(child: Expression) extends UnaryExpression { if (elements.contains(null)) { null } else { - val arrays = elements.map( - _.asInstanceOf[ArrayData].toObjectArray(elementType) - ) - val numberOfElements = arrays.foldLeft(0L)((sum, e) => sum + e.length) - if(numberOfElements > MAX_ARRAY_LENGTH) { + val arrayData = elements.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) + if (numberOfElements > MAX_ARRAY_LENGTH) { throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + s" $numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") } val flattenedData = new Array(numberOfElements.toInt) var position = 0 - for(a <- arrays) { - Array.copy(a, 0, flattenedData, position, a.length) - position += a.length + for (ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, flattenedData, position, arr.length) + position += arr.length } new GenericArrayData(flattenedData) } @@ -428,10 +427,10 @@ case class Flatten(child: Expression) extends UnaryExpression { childVariableName: String, coreLogic: String): String = { s""" - |for(int z=0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { + |for (int z=0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { | ${ev.isNull} |= $childVariableName.isNullAt(z); |} - 
|if(!${ev.isNull}) { + |if (!${ev.isNull}) { | $coreLogic |} """.stripMargin @@ -443,7 +442,7 @@ case class Flatten(child: Expression) extends UnaryExpression { val variableName = ctx.freshName("numElements") val code = s""" |long $variableName = 0; - |for(int z=0; z < $childVariableName.numElements(); z++) { + |for (int z=0; z < $childVariableName.numElements(); z++) { | $variableName += $childVariableName.getArray(z).numElements(); |} |if ($variableName > ${MAX_ARRAY_LENGTH}) { @@ -471,7 +470,8 @@ case class Flatten(child: Expression) extends UnaryExpression { | ${elementType.defaultSize}); |if ($arraySizeName > $MAX_ARRAY_LENGTH) { | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with" + - | " $arraySizeName bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes."); + | " $arraySizeName bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" + + | " bytes for UnsafeArrayData."); |} """.stripMargin val baseOffset = Platform.BYTE_ARRAY_OFFSET @@ -486,10 +486,10 @@ case class Flatten(child: Expression) extends UnaryExpression { |Platform.putLong($arrayName, $baseOffset, $numElemName); |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName); |int $counter = 0; - |for(int k=0; k < $childVariableName.numElements(); k++) { + |for (int k=0; k < $childVariableName.numElements(); k++) { | ArrayData arr = $childVariableName.getArray(k); - | for(int l = 0; l < arr.numElements(); l++) { - | if(arr.isNullAt(l)) { + | for (int l = 0; l < arr.numElements(); l++) { + | if (arr.isNullAt(l)) { | $tempArrayDataName.setNullAt($counter); | } else { | $tempArrayDataName.set$primitiveValueTypeName( @@ -517,9 +517,9 @@ case class Flatten(child: Expression) extends UnaryExpression { |$numElemCode |Object[] $arrayName = new Object[(int)$numElemName]; |int $counter = 0; - |for(int k=0; k < $childVariableName.numElements(); k++) { + |for (int k=0; k < $childVariableName.numElements(); k++) { | ArrayData arr = 
$childVariableName.getArray(k); - | for(int l = 0; l < arr.numElements(); l++) { + | for (int l = 0; l < arr.numElements(); l++) { | $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")}; | $counter++; | } From 207eb5a53489b0aed6a29209ee2953a78801db89 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Tue, 17 Apr 2018 14:26:03 +0200 Subject: [PATCH 6/8] [SPARK-23821][SQL] Improving codeGen. --- .../expressions/collectionOperations.scala | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e6ab78e46a85..7873ec921101 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -483,7 +483,7 @@ case class Flatten(child: Expression) extends UnaryExpression { } else { genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) } - nullElementsProtection(ev, c, code) + if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code }) } @@ -492,7 +492,7 @@ case class Flatten(child: Expression) extends UnaryExpression { childVariableName: String, coreLogic: String): String = { s""" - |for (int z=0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { + |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { | ${ev.isNull} |= $childVariableName.isNullAt(z); |} |if (!${ev.isNull}) { @@ -507,12 +507,12 @@ case class Flatten(child: Expression) extends UnaryExpression { val variableName = ctx.freshName("numElements") val code = s""" |long $variableName = 0; - |for (int z=0; z < $childVariableName.numElements(); z++) { + |for (int z = 0; z < $childVariableName.numElements(); z++) { | $variableName += 
$childVariableName.getArray(z).numElements(); |} - |if ($variableName > ${MAX_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with" + - | " $variableName elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |if ($variableName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); |} """.stripMargin (code, variableName) @@ -534,8 +534,8 @@ case class Flatten(child: Expression) extends UnaryExpression { | $numElemName, | ${elementType.defaultSize}); |if ($arraySizeName > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with" + - | " $arraySizeName bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" + + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + | $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" + | " bytes for UnsafeArrayData."); |} """.stripMargin @@ -551,7 +551,7 @@ case class Flatten(child: Expression) extends UnaryExpression { |Platform.putLong($arrayName, $baseOffset, $numElemName); |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName); |int $counter = 0; - |for (int k=0; k < $childVariableName.numElements(); k++) { + |for (int k = 0; k < $childVariableName.numElements(); k++) { | ArrayData arr = $childVariableName.getArray(k); | for (int l = 0; l < arr.numElements(); l++) { | if (arr.isNullAt(l)) { @@ -582,7 +582,7 @@ case class Flatten(child: Expression) extends UnaryExpression { |$numElemCode |Object[] $arrayName = new Object[(int)$numElemName]; |int $counter = 0; - |for (int k=0; k < $childVariableName.numElements(); k++) { + |for (int k = 0; k < $childVariableName.numElements(); k++) { | ArrayData arr = $childVariableName.getArray(k); | for (int l = 0; l < arr.numElements(); l++) { | 
$arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")}; From 10849d71adc37e565287b5f7a9230d62edd2c7fc Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Tue, 17 Apr 2018 14:41:53 +0200 Subject: [PATCH 7/8] [SPARK-23821][SQL] Removing extra space from the exception message. --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7873ec921101..43010a46682a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -463,7 +463,7 @@ case class Flatten(child: Expression) extends UnaryExpression { val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) if (numberOfElements > MAX_ARRAY_LENGTH) { throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - s" $numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") } val flattenedData = new Array(numberOfElements.toInt) var position = 0 From 37b68cd8e7e1b255ac026345594e8c2e61aae43b Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Thu, 19 Apr 2018 14:48:56 +0200 Subject: [PATCH 8/8] [SPARK-23821][SQL] Small refactoring --- .../expressions/collectionOperations.scala | 207 +++++++++--------- 1 file changed, 104 insertions(+), 103 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4d52ae599129..a9c5430bdb7b 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -564,6 +564,110 @@ case class ArrayPosition(left: Expression, right: Expression) } } +/** + * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, + accesses elements from the last to the first. Returns NULL if the index exceeds the length + of the array. + + _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 2); + 2 + > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2); + "b" + """, + since = "2.4.0") +case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + + override def dataType: DataType = left.dataType match { + case ArrayType(elementType, _) => elementType + case MapType(_, valueType, _) => valueType + } + + override def inputTypes: Seq[AbstractDataType] = { + Seq(TypeCollection(ArrayType, MapType), + left.dataType match { + case _: ArrayType => IntegerType + case _: MapType => left.dataType.asInstanceOf[MapType].keyType + } + ) + } + + override def nullable: Boolean = true + + override def nullSafeEval(value: Any, ordinal: Any): Any = { + left.dataType match { + case _: ArrayType => + val array = value.asInstanceOf[ArrayData] + val index = ordinal.asInstanceOf[Int] + if (array.numElements() < math.abs(index)) { + null + } else { + val idx = if (index == 0) { + throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") + } else if (index > 0) { + index - 1 + } else { + array.numElements() + index + } + if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { + null + } else { + array.get(idx, dataType) + } + } + case 
_: MapType => + getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + left.dataType match { + case _: ArrayType => + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("elementAtIndex") + val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($eval1.isNullAt($index)) { + | ${ev.isNull} = true; + |} else + """.stripMargin + } else { + "" + } + s""" + |int $index = (int) $eval2; + |if ($eval1.numElements() < Math.abs($index)) { + | ${ev.isNull} = true; + |} else { + | if ($index == 0) { + | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); + | } else if ($index > 0) { + | $index--; + | } else { + | $index += $eval1.numElements(); + | } + | $nullCheck + | { + | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + | } + |} + """.stripMargin + }) + case _: MapType => + doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType]) + } + } + + override def prettyName: String = "element_at" +} + /** * Transforms an array of arrays into a single array. */ @@ -740,106 +844,3 @@ case class Flatten(child: Expression) extends UnaryExpression { override def prettyName: String = "flatten" } -/** - * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`. - */ -@ExpressionDescription( - usage = """ - _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, - accesses elements from the last to the first. Returns NULL if the index exceeds the length - of the array. 
- - _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map - """, - examples = """ - Examples: - > SELECT _FUNC_(array(1, 2, 3), 2); - 2 - > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2); - "b" - """, - since = "2.4.0") -case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { - - override def dataType: DataType = left.dataType match { - case ArrayType(elementType, _) => elementType - case MapType(_, valueType, _) => valueType - } - - override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(ArrayType, MapType), - left.dataType match { - case _: ArrayType => IntegerType - case _: MapType => left.dataType.asInstanceOf[MapType].keyType - } - ) - } - - override def nullable: Boolean = true - - override def nullSafeEval(value: Any, ordinal: Any): Any = { - left.dataType match { - case _: ArrayType => - val array = value.asInstanceOf[ArrayData] - val index = ordinal.asInstanceOf[Int] - if (array.numElements() < math.abs(index)) { - null - } else { - val idx = if (index == 0) { - throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") - } else if (index > 0) { - index - 1 - } else { - array.numElements() + index - } - if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { - null - } else { - array.get(idx, dataType) - } - } - case _: MapType => - getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) - } - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - left.dataType match { - case _: ArrayType => - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - val index = ctx.freshName("elementAtIndex") - val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |if ($eval1.isNullAt($index)) { - | ${ev.isNull} = true; - |} else - """.stripMargin - } else { - "" - } - s""" - |int $index = (int) $eval2; - |if ($eval1.numElements() < Math.abs($index)) { - | ${ev.isNull} = true; - |} else { - | if 
($index == 0) { - | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); - | } else if ($index > 0) { - | $index--; - | } else { - | $index += $eval1.numElements(); - | } - | $nullCheck - | { - | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; - | } - |} - """.stripMargin - }) - case _: MapType => - doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType]) - } - } - - override def prettyName: String = "element_at" -}