From 3875e1fe20f7032c3da007bcdccd9f2c627710db Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 18 Jun 2016 15:18:33 +0900 Subject: [PATCH] eliminate nullcheck code if all of the elements do not have null add unit tests --- .../codegen/GenerateUnsafeProjection.scala | 44 ++++++++++++------- .../spark/sql/DataFrameComplexTypeSuite.scala | 34 ++++++++++++++ 2 files changed, 62 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 5efba4b3a6087..338b368122f8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -117,22 +117,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ - case a @ ArrayType(et, _) => + case a @ ArrayType(et, cn) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)} + ${writeArrayToBuffer(ctx, input.value, et, cn, bufferHolder)} $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor); """ - case m @ MapType(kt, vt, _) => + case m @ MapType(kt, vt, cn) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. 
final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)} + ${writeMapToBuffer(ctx, input.value, kt, vt, cn, bufferHolder)} $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor); """ @@ -173,6 +173,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, elementType: DataType, + containsNull: Boolean, bufferHolder: String): String = { val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") @@ -202,16 +203,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} """ - case a @ ArrayType(et, _) => + case a @ ArrayType(et, cn) => s""" $arrayWriter.setOffset($index); - ${writeArrayToBuffer(ctx, element, et, bufferHolder)} + ${writeArrayToBuffer(ctx, element, et, cn, bufferHolder)} """ - case m @ MapType(kt, vt, _) => + case m @ MapType(kt, vt, cn) => s""" $arrayWriter.setOffset($index); - ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} + ${writeMapToBuffer(ctx, element, kt, vt, cn, bufferHolder)} """ case t: DecimalType => @@ -222,6 +223,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$arrayWriter.write($index, $element);" } + val storeElement = if (containsNull) { + s""" + if ($input.isNullAt($index)) { + $arrayWriter.setNullAt($index); + } else { + final $jt $element = ${ctx.getValue(input, et, index)}; + $writeElement + } + """ + } else { + s""" + final $jt $element = ${ctx.getValue(input, et, index)}; + $writeElement + """ + } s""" if ($input instanceof UnsafeArrayData) { ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} @@ -230,12 +246,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro 
$arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize); for (int $index = 0; $index < $numElements; $index++) { - if ($input.isNullAt($index)) { - $arrayWriter.setNullAt($index); - } else { - final $jt $element = ${ctx.getValue(input, et, index)}; - $writeElement - } + $storeElement } } """ @@ -247,6 +258,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro input: String, keyType: DataType, valueType: DataType, + valueContainsNull: Boolean, bufferHolder: String): String = { val keys = ctx.freshName("keys") val values = ctx.freshName("values") @@ -268,11 +280,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Remember the current cursor so that we can write numBytes of key array later. final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)} + ${writeArrayToBuffer(ctx, keys, keyType, false, bufferHolder)} // Write the numBytes of key array into the first 4 bytes. Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor); - ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)} + ${writeArrayToBuffer(ctx, values, valueType, valueContainsNull, bufferHolder)} } """ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 72f676e6225ee..bb05c5e2cf454 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -26,6 +26,40 @@ import org.apache.spark.sql.test.SharedSQLContext class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { import testImplicits._ + test("primitive type on array") { + val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v"). 
+ selectExpr("Array(v + 2, v + 3)").collect + QueryTest.sameRows(Seq(Row(Seq(3, 4)), Row(Seq(4, 5))), rows.toSeq).foreach(fail(_)) + } + + test("primitive type and null on array") { + val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v"). + selectExpr("Array(v + 2, null, v + 3)").collect + QueryTest.sameRows(Seq(Row(Seq(3, null, 4)), Row(Seq(4, null, 5))), rows.toSeq).foreach(fail(_)) + } + + test("array with null on array") { + val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v"). + selectExpr("Array(Array(v, v + 1)," + + "null," + + "Array(v, v - 1))").collect + QueryTest.sameRows(Seq( + Row(Seq(Seq(1, 2), null, Seq(1, 0))), + Row(Seq(Seq(2, 3), null, Seq(2, 1)))), rows.toSeq).foreach(fail(_)) + } + + test("primitive type on map") { + val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v"). + selectExpr("map(v + 3, v + 4)").collect + QueryTest.sameRows(Seq(Row(Map(4 -> 5)), Row(Map(5 -> 6))), rows.toSeq).foreach(fail(_)) + } + + test("map with null value on map") { + val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v"). + selectExpr("map(v, null)").collect + QueryTest.sameRows(Seq(Row(Map(1 -> null)), Row(Map(2 -> null))), rows.toSeq).foreach(fail(_)) + } + + test("UDF on struct") { val f = udf((a: String) => a) val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")