From f64f570430648983fd8a3664c526d96de710ea64 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 23 Jun 2016 15:33:04 +0900 Subject: [PATCH 1/4] Add tests to check if RowEncoder preserves array/map nullability. --- .../catalyst/encoders/RowEncoderSuite.scala | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 2e513ea22c15..bcd76bddb072 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -247,6 +247,39 @@ class RowEncoderSuite extends SparkFunSuite { assert(e4.getMessage.contains("java.lang.String is not a valid external type")) } + for { + elementType <- Seq(IntegerType, StringType) + containsNull <- Seq(true, false) + nullable <- Seq(true, false) + } { + test("RowEncoder should preserve array nullability: " + + s"ArrayType($elementType, containsNull = $containsNull), nullable = $nullable") { + val schema = new StructType().add("array", ArrayType(elementType, containsNull), nullable) + val encoder = RowEncoder(schema).resolveAndBind() + assert(encoder.serializer.length == 1) + assert(encoder.serializer.head.dataType == ArrayType(elementType, containsNull)) + assert(encoder.serializer.head.nullable == nullable) + } + } + + for { + keyType <- Seq(IntegerType, StringType) + valueType <- Seq(IntegerType, StringType) + valueContainsNull <- Seq(true, false) + nullable <- Seq(true, false) + } { + test("RowEncoder should preserve map nullability: " + + s"MapType($keyType, $valueType, valueContainsNull = $valueContainsNull), " + + s"nullable = $nullable") { + val schema = new StructType().add( + "map", MapType(keyType, valueType, valueContainsNull), nullable) + val encoder = RowEncoder(schema).resolveAndBind() + assert(encoder.serializer.length == 1) + assert(encoder.serializer.head.dataType == MapType(keyType, valueType, valueContainsNull)) + assert(encoder.serializer.head.nullable == nullable) + } + } + private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { val encoder = RowEncoder(schema).resolveAndBind() From 093a9fa9039e1403e5f9ec9e0ab5f422f97f0fc4 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 23 Jun 2016 15:44:39 +0900 Subject: [PATCH 2/4] Modify RowEncoder and MapObjects to preserve array/map nullability. --- .../sql/catalyst/encoders/RowEncoder.scala | 38 ++++++++++++++++--- .../expressions/objects/objects.scala | 23 +++++++---- 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 67fca153b551..e8c23f9b0996 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -119,15 +119,32 @@ object RowEncoder { "fromString", inputObject :: Nil) - case t @ ArrayType(et, _) => et match { + case t @ ArrayType(et, containsNull) => et match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => // TODO: validate input type for primitive array. - NewInstance( + val nonNullOutput = NewInstance( classOf[GenericArrayData], inputObject :: Nil, - dataType = t) + dataType = t, + propagateNull = false) + + if (inputObject.nullable) { + If(IsNull(inputObject), + Literal.create(null, inputType), + nonNullOutput) + } else { + nonNullOutput + } + case _ => MapObjects( - element => serializerFor(ValidateExternalType(element, et), et), + { element => + val value = serializerFor(ValidateExternalType(element, et), et) + if (!containsNull) { + AssertNotNull(value, Seq.empty) + } else { + value + } + }, inputObject, ObjectType(classOf[Object])) } @@ -147,10 +164,19 @@ object RowEncoder { ObjectType(classOf[scala.collection.Seq[_]])) val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) - NewInstance( + val nonNullOutput = NewInstance( classOf[ArrayBasedMapData], convertedKeys :: convertedValues :: Nil, - dataType = t) + dataType = t, + propagateNull = false) + + if (inputObject.nullable) { + If(IsNull(inputObject), + Literal.create(null, inputType), + nonNullOutput) + } else { + nonNullOutput + } case StructType(fields) => val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index c597a2a70944..c820e381b018 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -376,14 +376,15 @@ case class MapObjects private( lambdaFunction: Expression, inputData: Expression) extends Expression with NonSQLExpression { - override def nullable: Boolean = true + override def nullable: Boolean = inputData.nullable override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def dataType: DataType = ArrayType(lambdaFunction.dataType) + override def dataType: DataType = + ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVar.dataType) @@ -450,6 +451,18 @@ case class MapObjects private( case _ => s"${loopVar.isNull} = ${loopVar.value} == null;" } + val setValue = if (lambdaFunction.nullable) { + s""" + if (${genFunction.isNull}) { + $convertedArray[$loopIndex] = null; + } else { + $convertedArray[$loopIndex] = ${genFunction.value}; + } + """ + } else { + s"$convertedArray[$loopIndex] = ${genFunction.value};" + } + val code = s""" ${genInputData.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -466,11 +479,7 @@ case class MapObjects private( $loopNullCheck ${genFunction.code} - if (${genFunction.isNull}) { - $convertedArray[$loopIndex] = null; - } else { - $convertedArray[$loopIndex] = ${genFunction.value}; - } + $setValue $loopIndex += 1; } From be6ed48d5ed142f6c6101b330e04feb4eb7207e1 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 23 Jun 2016 18:42:45 +0900 Subject: [PATCH 3/4] Fix nullability of NewInstance. --- .../spark/sql/catalyst/encoders/RowEncoder.scala | 13 ++----------- .../sql/catalyst/expressions/objects/objects.scala | 4 ++-- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index e8c23f9b0996..e1a94a9cf18c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -122,19 +122,10 @@ object RowEncoder { case t @ ArrayType(et, containsNull) => et match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => // TODO: validate input type for primitive array. - val nonNullOutput = NewInstance( + NewInstance( classOf[GenericArrayData], inputObject :: Nil, - dataType = t, - propagateNull = false) - - if (inputObject.nullable) { - If(IsNull(inputObject), - Literal.create(null, inputType), - nonNullOutput) - } else { - nonNullOutput - } + dataType = t) case _ => MapObjects( { element => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index c820e381b018..e21031da5b4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -213,7 +213,7 @@ case class NewInstance( outerPointer: Option[() => AnyRef]) extends Expression with NonSQLExpression { private val className = cls.getName - override def nullable: Boolean = propagateNull + override def nullable: Boolean = propagateNull && arguments.exists(_.nullable) override def children: Seq[Expression] = arguments @@ -238,7 +238,7 @@ case class NewInstance( val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx)) var isNull = ev.isNull - val setIsNull = if (propagateNull && arguments.nonEmpty) { + val setIsNull = if (nullable) { s"final boolean $isNull = ${argGen.map(_.isNull).mkString(" || ")};" } else { isNull = "false" From 9caf4257d780b43e049dbbef75d525213aad1d18 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 5 Jul 2017 16:07:45 +0900 Subject: [PATCH 4/4] Address a comment. --- .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 6d9172fbf92a..43c35bbdf383 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -134,7 +134,7 @@ object RowEncoder { returnNullable = false) case _ => MapObjects( - { element => + element => { val value = serializerFor(ValidateExternalType(element, et), et) if (!containsNull) { AssertNotNull(value, Seq.empty)