diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index cc32fac67e924..43c35bbdf383a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -123,7 +123,7 @@ object RowEncoder {
         inputObject :: Nil,
         returnNullable = false)
 
-    case t @ ArrayType(et, cn) =>
+    case t @ ArrayType(et, containsNull) =>
       et match {
         case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
           StaticInvoke(
@@ -132,8 +132,16 @@ object RowEncoder {
             "toArrayData",
             inputObject :: Nil,
             returnNullable = false)
+
         case _ => MapObjects(
-          element => serializerFor(ValidateExternalType(element, et), et),
+          element => {
+            val value = serializerFor(ValidateExternalType(element, et), et)
+            if (!containsNull) {
+              AssertNotNull(value, Seq.empty)
+            } else {
+              value
+            }
+          },
           inputObject,
           ObjectType(classOf[Object]))
       }
@@ -155,10 +163,19 @@ object RowEncoder {
           ObjectType(classOf[scala.collection.Seq[_]]),
           returnNullable = false)
       val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))
-      NewInstance(
+      val nonNullOutput = NewInstance(
         classOf[ArrayBasedMapData],
         convertedKeys :: convertedValues :: Nil,
-        dataType = t)
+        dataType = t,
+        propagateNull = false)
+
+      if (inputObject.nullable) {
+        If(IsNull(inputObject),
+          Literal.create(null, inputType),
+          nonNullOutput)
+      } else {
+        nonNullOutput
+      }
 
     case StructType(fields) =>
       val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 1a5569a77dc7a..6ed175f86ca77 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -273,6 +273,39 @@ class RowEncoderSuite extends SparkFunSuite {
     assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
   }
 
+  for {
+    elementType <- Seq(IntegerType, StringType)
+    containsNull <- Seq(true, false)
+    nullable <- Seq(true, false)
+  } {
+    test("RowEncoder should preserve array nullability: " +
+      s"ArrayType($elementType, containsNull = $containsNull), nullable = $nullable") {
+      val schema = new StructType().add("array", ArrayType(elementType, containsNull), nullable)
+      val encoder = RowEncoder(schema).resolveAndBind()
+      assert(encoder.serializer.length == 1)
+      assert(encoder.serializer.head.dataType == ArrayType(elementType, containsNull))
+      assert(encoder.serializer.head.nullable == nullable)
+    }
+  }
+
+  for {
+    keyType <- Seq(IntegerType, StringType)
+    valueType <- Seq(IntegerType, StringType)
+    valueContainsNull <- Seq(true, false)
+    nullable <- Seq(true, false)
+  } {
+    test("RowEncoder should preserve map nullability: " +
+      s"MapType($keyType, $valueType, valueContainsNull = $valueContainsNull), " +
+      s"nullable = $nullable") {
+      val schema = new StructType().add(
+        "map", MapType(keyType, valueType, valueContainsNull), nullable)
+      val encoder = RowEncoder(schema).resolveAndBind()
+      assert(encoder.serializer.length == 1)
+      assert(encoder.serializer.head.dataType == MapType(keyType, valueType, valueContainsNull))
+      assert(encoder.serializer.head.nullable == nullable)
+    }
+  }
+
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
       val encoder = RowEncoder(schema).resolveAndBind()