From 086e223e89ce0cf56145f4aa9a7aef5421a98810 Mon Sep 17 00:00:00 2001
From: Marek Novotny
Date: Fri, 4 May 2018 20:00:40 +0200
Subject: [PATCH 1/5] [SPARK-23935][SQL] Adding map_entries function

---
 python/pyspark/sql/functions.py               |  20 +++
 .../catalyst/analysis/FunctionRegistry.scala  |   1 +
 .../expressions/collectionOperations.scala    | 156 ++++++++++++++++++
 .../CollectionExpressionsSuite.scala          |  23 +++
 .../expressions/ExpressionEvalHelper.scala    |   3 +
 .../org/apache/spark/sql/functions.scala      |   7 +
 .../spark/sql/DataFrameFunctionsSuite.scala   |  44 +++++
 7 files changed, 254 insertions(+)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index ad4bd6f5089e..c5a90ee33443 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2273,6 +2273,26 @@ def map_values(col):
     return Column(sc._jvm.functions.map_values(_to_java_column(col)))
 
 
+@since(2.4)
+def map_entries(col):
+    """
+    Collection function: Returns an unordered array of all entries in the given map.
+
+    :param col: name of column or expression
+
+    >>> from pyspark.sql.functions import map_entries
+    >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
+    >>> df.select(map_entries("data").alias("entries")).show()
+    +----------------+
+    |         entries|
+    +----------------+
+    |[[1, a], [2, b]]|
+    +----------------+
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.map_entries(_to_java_column(col)))
+
+
 # ---------------------------- User Defined Function ----------------------------------
 
 class PandasUDFType(object):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 51bb6b0abe40..23149b897537 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -408,6 +408,7 @@ object FunctionRegistry {
     expression[ElementAt]("element_at"),
     expression[MapKeys]("map_keys"),
     expression[MapValues]("map_values"),
+    expression[MapEntries]("map_entries"),
     expression[Size]("size"),
     expression[Size]("cardinality"),
     expression[SortArray]("sort_array"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 6d63a531e3b7..1e583f5c2f96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -118,6 +118,162 @@ case class MapValues(child: Expression)
   override def prettyName: String = "map_values"
 }
 
+/**
+ * Returns an unordered array of all entries in the given map.
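+ * Each entry is a struct with two fields, "key" and "value".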
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(map(1, 'a', 2, 'b'));
+       [(1,"a"),(2,"b")]
+  """,
+  since = "2.4.0")
+case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
+
+  lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]
+
+  override def dataType: DataType = {
+    ArrayType(
+      StructType(
+        StructField("key", childDataType.keyType, false) ::
+        StructField("value", childDataType.valueType, childDataType.valueContainsNull) ::
+        Nil),
+      false)
+  }
+
+  override protected def nullSafeEval(input: Any): Any = {
+    val childMap = input.asInstanceOf[MapData]
+    val keys = childMap.keyArray()
+    val values = childMap.valueArray()
+    val length = childMap.numElements()
+    val resultData = new Array[AnyRef](length)
+    var i = 0
+    while (i < length) {
+      val key = keys.get(i, childDataType.keyType)
+      val value = values.get(i, childDataType.valueType)
+      val row = new GenericInternalRow(Array[Any](key, value))
+      resultData.update(i, row)
+      i += 1
+    }
+    new GenericArrayData(resultData)
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, c => {
+      val numElements = ctx.freshName("numElements")
+      val keys = ctx.freshName("keys")
+      val values = ctx.freshName("values")
+      val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType)
+      val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
+      val code = if (isKeyPrimitive && isValuePrimitive) {
+        genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements)
+      } else {
+        genCodeForAnyElements(ctx, keys, values, ev.value, numElements)
+      }
+      s"""
+         |final int $numElements = $c.numElements();
+         |final ArrayData $keys = $c.keyArray();
+         |final ArrayData $values = $c.valueArray();
+         |$code
+       """.stripMargin
+    })
+  }
+
+  private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z")
+
+  private def getValue(varName: String) = {
+    CodeGenerator.getValue(varName, childDataType.valueType, "z")
+  }
+
+  private def genCodeForPrimitiveElements(
+      ctx: CodegenContext,
+      keys: String,
+      values: String,
+      arrayData: String,
+      numElements: String): String = {
+    val byteArraySize = ctx.freshName("byteArraySize")
+    val data = ctx.freshName("byteArray")
+    val unsafeRow = ctx.freshName("unsafeRow")
+    val structSize = ctx.freshName("structSize")
+    val unsafeArrayData = ctx.freshName("unsafeArrayData")
+    val structsOffset = ctx.freshName("structsOffset")
+    val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray"
+    val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
+
+    val baseOffset = Platform.BYTE_ARRAY_OFFSET
+    val longSize = LongType.defaultSize
+    val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
+    val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType)
+
+    val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});"
+    val valueAssignmentChecked = if (childDataType.valueContainsNull) {
+      s"""
+         |if ($values.isNullAt(z)) {
+         |  $unsafeRow.setNullAt(1);
+         |} else {
+         |  $valueAssignment
+         |}
+       """.stripMargin
+    } else {
+      valueAssignment
+    }
+
+    // The backing byte array holds one 8-byte slot per entry, encoding
+    // (offset << 32) + structSize, followed by the fixed-size entry structs themselves.
+    s"""
+       |final int $structSize = ${UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2};
+       |final long $byteArraySize = $calculateArraySize($numElements, $longSize + $structSize);
+       |final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize;
+       |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+       |  ${genCodeForAnyElements(ctx, keys, values, arrayData, numElements)}
+       |} else {
+       |  final byte[] $data = new byte[(int)$byteArraySize];
+       |  UnsafeArrayData $unsafeArrayData = new UnsafeArrayData();
+       |  Platform.putLong($data, $baseOffset, $numElements);
+       |  $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize);
+       |  UnsafeRow $unsafeRow = new UnsafeRow(2);
+       |  for (int z = 0; z < $numElements; z++) {
+       |    long offset = $structsOffset + z * $structSize;
+       |    $unsafeArrayData.setLong(z, (offset << 32) + $structSize);
+       |    $unsafeRow.pointTo($data, $baseOffset + offset, $structSize);
+       |    $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
+       |    $valueAssignmentChecked
+       |  }
+       |  $arrayData = $unsafeArrayData;
+       |}
+     """.stripMargin
+  }
+
+  private def genCodeForAnyElements(
+      ctx: CodegenContext,
+      keys: String,
+      values: String,
+      arrayData: String,
+      numElements: String): String = {
+    val genericArrayClass = classOf[GenericArrayData].getName
+    val rowClass = classOf[GenericInternalRow].getName
+    val data = ctx.freshName("internalRowArray")
+
+    val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
+    val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
+        s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
+      } else {
+        getValue(values)
+      }
+
+    s"""
+       |final Object[] $data = new Object[$numElements];
+       |for (int z = 0; z < $numElements; z++) {
+       |  $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck});
+       |}
+       |$arrayData = new $genericArrayClass($data);
+     """.stripMargin
+  }
+
+  override def prettyName: String = "map_entries"
+}
+
 /**
  * Sorts the input array in ascending / descending order according to the natural ordering of
  * the array elements and returns it.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 7048d93fd564..5b6123971112 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
 class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -56,6 +57,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(MapValues(m2), null)
   }
 
+  test("MapEntries") {
+    def r(values: Any*): InternalRow = create_row(values: _*)
+
+    // Primitive-type keys/values
+    val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType))
+    val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType))
+    val mi2 = Literal.create(null, MapType(IntegerType, IntegerType))
+
+    checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2)))
+    checkEvaluation(MapEntries(mi1), Seq.empty)
+    checkEvaluation(MapEntries(mi2), null)
+
+    // Non-primitive-type keys/values
+    val ms0 = Literal.create(Map("a" -> "a", "b" -> null, "c" -> "b"), MapType(StringType, StringType))
+    val ms1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+    val ms2 = Literal.create(null, MapType(StringType, StringType))
+
+    checkEvaluation(MapEntries(ms0), Seq(r("a", "a"), r("b", null), r("c", "b")))
+    checkEvaluation(MapEntries(ms1), Seq.empty)
+    checkEvaluation(MapEntries(ms2), null)
+  }
+
   test("Sort Array") {
     val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
     val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index b4bf6d7107d7..e739f1a6b4cf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
       if (expected.isNaN) result.isNaN else expected == result
     case (result: Float, expected: Float) =>
       if (expected.isNaN) result.isNaN else expected == result
+    case (result: UnsafeRow, expected: GenericInternalRow) =>
+      val structType = exprDataType.asInstanceOf[StructType]
+      result.toSeq(structType) == expected.toSeq(structType)
     case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema)
     case _ => result == expected
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d2e22fa35551..0f1c8783918d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3392,6 +3392,13 @@ object functions {
    */
   def map_values(e: Column): Column = withExpr { MapValues(e.expr) }
 
+  /**
+   * Returns an unordered array of all entries in the given map.
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }
+
   // scalastyle:off line.size.limit
   // scalastyle:off parameter.number
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index a5163accb1bb..49c8fa73e255 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -381,6 +381,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("map_entries") {
+    // The dummy filter keeps the optimizer from collapsing the query into a LocalRelation,
+    // so the codegen path of the expression is exercised as well.
+    val dummyFilter = (c: Column) => c.isNotNull || c.isNull
+
+    // Primitive-type elements
+    val idf = Seq(
+      Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300),
+      Map[Int, Int](),
+      null
+    ).toDF("m")
+    val iExpected = Seq(
+      Row(Seq(Row(1, 100), Row(2, 200), Row(3, 300))),
+      Row(Seq.empty),
+      Row(null)
+    )
+
+    checkAnswer(idf.select(map_entries('m)), iExpected)
+    checkAnswer(idf.selectExpr("map_entries(m)"), iExpected)
+    checkAnswer(idf.filter(dummyFilter('m)).select(map_entries('m)), iExpected)
+    checkAnswer(
+      spark.range(1).selectExpr("map_entries(map(1, null, 2, null))"),
+      Seq(Row(Seq(Row(1, null), Row(2, null)))))
+    checkAnswer(
+      spark.range(1).filter(dummyFilter('id)).selectExpr("map_entries(map(1, null, 2, null))"),
+      Seq(Row(Seq(Row(1, null), Row(2, null)))))
+
+    // Non-primitive-type elements
+    val sdf = Seq(
+      Map[String, String]("a" -> "f", "b" -> "o", "c" -> "o"),
+      Map[String, String]("a" -> null, "b" -> null),
+      Map[String, String](),
+      null
+    ).toDF("m")
+    val sExpected = Seq(
+      Row(Seq(Row("a", "f"), Row("b", "o"), Row("c", "o"))),
+      Row(Seq(Row("a", null), Row("b", null))),
+      Row(Seq.empty),
+      Row(null)
+    )
+
+    checkAnswer(sdf.select(map_entries('m)), sExpected)
+    checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected)
+    checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected)
+  }
+
   test("array contains function") {
     val df = Seq(
       (Seq[Int](1, 2), "x"),

From b9e240944e41674a25fd80e4c80b2f7c83af14ef Mon Sep 17 00:00:00 2001
From: Marek Novotny
Date: Sat, 5 May 2018 00:14:56 +0200
Subject: [PATCH 2/5] [SPARK-23935][SQL] Fixing a Scala style problem in tests

---
 .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 5b6123971112..3785c5e6e21b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -70,11 +70,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(MapEntries(mi2), null)
 
     // Non-primitive-type keys/values
-    val ms0 = Literal.create(Map("a" -> "a", "b" -> null, "c" -> "b"), MapType(StringType, StringType))
+    val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType))
     val ms1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
     val ms2 = Literal.create(null, MapType(StringType, StringType))
 
-    checkEvaluation(MapEntries(ms0), Seq(r("a", "a"), r("b", null), r("c", "b")))
+    checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null)))
     checkEvaluation(MapEntries(ms1), Seq.empty)
     checkEvaluation(MapEntries(ms2), null)
   }
 

From d05ad9be40064c61b05d838b3ba96b02267d5ee1 Mon Sep 17 00:00:00 2001
From: Marek Novotny
Date: Mon, 7 May 2018 15:28:25 +0200
Subject: [PATCH 3/5] [SPARK-23935][SQL] Inlining struct size and moving
 structsOffset into else branch.

---
 .../sql/catalyst/expressions/collectionOperations.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 58337542d2ff..0f2fe1d6a426 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
@@ -197,7 +198,6 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
     val byteArraySize = ctx.freshName("byteArraySize")
     val data = ctx.freshName("byteArray")
     val unsafeRow = ctx.freshName("unsafeRow")
-    val structSize = ctx.freshName("structSize")
     val unsafeArrayData = ctx.freshName("unsafeArrayData")
     val structsOffset = ctx.freshName("structsOffset")
     val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray"
@@ -205,6 +205,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
 
     val baseOffset = Platform.BYTE_ARRAY_OFFSET
     val longSize = LongType.defaultSize
+    val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2
     val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
     val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType)
 
@@ -222,12 +223,11 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
     }
 
     s"""
-       |final int $structSize = ${UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2};
-       |final long $byteArraySize = $calculateArraySize($numElements, $longSize + $structSize);
-       |final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize;
+       |final long $byteArraySize = $calculateArraySize($numElements, ${longSize + structSize});
        |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
        |  ${genCodeForAnyElements(ctx, keys, values, arrayData, numElements)}
        |} else {
+       |  final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize;
        |  final byte[] $data = new byte[(int)$byteArraySize];
        |  UnsafeArrayData $unsafeArrayData = new UnsafeArrayData();
        |  Platform.putLong($data, $baseOffset, $numElements);

From 56ff20ac977ca1a305e96a7582789e2e75e6718c Mon Sep 17 00:00:00 2001
From: Marek Novotny
Date: Sun, 13 May 2018 22:50:58 +0200
Subject: [PATCH 4/5] [SPARK-23935][SQL] Introducing long constant for struct size.
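
Appending "L" to the emitted struct-size constant makes the generated loop
perform its offset arithmetic in 64-bit math, so "z * structSize" cannot
overflow int for very large arrays.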
---
 .../sql/catalyst/expressions/collectionOperations.scala | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 0f2fe1d6a426..5ca67e6eb767 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -206,6 +206,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
     val baseOffset = Platform.BYTE_ARRAY_OFFSET
     val longSize = LongType.defaultSize
     val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2
+    val structSizeAsLong = structSize + "L"
     val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
     val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType)
 
@@ -234,8 +235,8 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
        |  $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize);
        |  UnsafeRow $unsafeRow = new UnsafeRow(2);
        |  for (int z = 0; z < $numElements; z++) {
-       |    long offset = $structsOffset + z * $structSize;
-       |    $unsafeArrayData.setLong(z, (offset << 32) + $structSize);
+       |    long offset = $structsOffset + z * $structSizeAsLong;
+       |    $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
        |    $unsafeRow.pointTo($data, $baseOffset + offset, $structSize);
        |    $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
        |    $valueAssignmentChecked

From 1bd0d5ea78908820140bd13c624a8e90f1e737fe Mon Sep 17 00:00:00 2001
From: Marek Novotny
Date: Mon, 14 May 2018 17:19:30 +0200
Subject: [PATCH 5/5] [SPARK-23935][SQL] Addressing review comments.

---
 .../sql/catalyst/expressions/UnsafeRow.java   |  2 +
 .../expressions/codegen/CodeGenerator.scala   | 34 +++++++++++
 .../expressions/collectionOperations.scala    | 56 +++++++++----------
 3 files changed, 62 insertions(+), 30 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 29a1411241cf..469b0e60cc9a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -62,6 +62,8 @@
  */
 public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable {
 
+  /** The size in bytes of a word, the granularity at which UnsafeRow fields are stored. */
+  public static final int WORD_SIZE = 8;
+
   //////////////////////////////////////////////////////////////////////////////
   // Static methods
   //////////////////////////////////////////////////////////////////////////////
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 4dda52529425..d382d9aace10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -764,6 +764,40 @@ class CodegenContext {
     """.stripMargin
   }
 
+  /**
+   * Generates code creating an [[UnsafeArrayData]]. The generated code executes
+   * a provided fallback when the size of the backing array would exceed the array size limit.
+   *
+   * @param arrayName the name of the array variable to create
+   * @param numElements a piece of code representing the number of elements the array should
+   *                    contain
+   * @param elementSize the size of an element in bytes
+   * @param bodyCode a function generating code that fills up the [[UnsafeArrayData]];
+   *                 it receives the name of the backing byte array as its parameter
+   * @param fallbackCode a piece of code executed when the array size limit is exceeded
+   */
+  def createUnsafeArrayWithFallback(
+      arrayName: String,
+      numElements: String,
+      elementSize: Int,
+      bodyCode: String => String,
+      fallbackCode: String): String = {
+    val arraySize = freshName("size")
+    val arrayBytes = freshName("arrayBytes")
+    s"""
+       |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
+       |  $numElements,
+       |  $elementSize);
+       |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+       |  $fallbackCode
+       |} else {
+       |  final byte[] $arrayBytes = new byte[(int)$arraySize];
+       |  UnsafeArrayData $arrayName = new UnsafeArrayData();
+       |  Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
+       |  $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
+       |  ${bodyCode(arrayBytes)}
+       |}
+     """.stripMargin
+  }
+
   /**
    * Generates code to do null safe execution, i.e. only execute the code when the input is not
    * null by adding null check if necessary.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 5ca67e6eb767..baeefae57099 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -195,17 +195,14 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
       values: String,
       arrayData: String,
       numElements: String): String = {
-    val byteArraySize = ctx.freshName("byteArraySize")
-    val data = ctx.freshName("byteArray")
     val unsafeRow = ctx.freshName("unsafeRow")
     val unsafeArrayData = ctx.freshName("unsafeArrayData")
     val structsOffset = ctx.freshName("structsOffset")
-    val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray"
     val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
 
     val baseOffset = Platform.BYTE_ARRAY_OFFSET
-    val longSize = LongType.defaultSize
-    val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2
+    val wordSize = UnsafeRow.WORD_SIZE
+    val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
     val structSizeAsLong = structSize + "L"
     val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
     val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType)
@@ -223,27 +220,26 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
       valueAssignment
     }
 
-    s"""
-       |final long $byteArraySize = $calculateArraySize($numElements, ${longSize + structSize});
-       |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
-       |  ${genCodeForAnyElements(ctx, keys, values, arrayData, numElements)}
-       |} else {
-       |  final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize;
-       |  final byte[] $data = new byte[(int)$byteArraySize];
-       |  UnsafeArrayData $unsafeArrayData = new UnsafeArrayData();
-       |  Platform.putLong($data, $baseOffset, $numElements);
-       |  $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize);
-       |  UnsafeRow $unsafeRow = new UnsafeRow(2);
-       |  for (int z = 0; z < $numElements; z++) {
-       |    long offset = $structsOffset + z * $structSizeAsLong;
-       |    $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
-       |    $unsafeRow.pointTo($data, $baseOffset + offset, $structSize);
-       |    $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
-       |    $valueAssignmentChecked
-       |  }
-       |  $arrayData = $unsafeArrayData;
-       |}
-     """.stripMargin
+    // For each entry, write the (offset << 32) + size slot and the entry struct itself;
+    // the backing byte array is supplied by createUnsafeArrayWithFallback.
+    val assignmentLoop = (byteArray: String) =>
+      s"""
+         |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize;
+         |UnsafeRow $unsafeRow = new UnsafeRow(2);
+         |for (int z = 0; z < $numElements; z++) {
+         |  long offset = $structsOffset + z * $structSizeAsLong;
+         |  $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
+         |  $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize);
+         |  $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
+         |  $valueAssignmentChecked
+         |}
+         |$arrayData = $unsafeArrayData;
+       """.stripMargin
+
+    ctx.createUnsafeArrayWithFallback(
+      unsafeArrayData,
+      numElements,
+      structSize + wordSize,
+      assignmentLoop,
+      genCodeForAnyElements(ctx, keys, values, arrayData, numElements))
   }
 
   private def genCodeForAnyElements(
@@ -258,10 +254,10 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
 
     val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
     val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
-        s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
-      } else {
-        getValue(values)
-      }
+      s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
+    } else {
+      getValue(values)
+    }
 
     s"""
        |final Object[] $data = new Object[$numElements];