Skip to content

Commit 873f3ad

Browse files
ueshin authored and cloud-fan committed
[SPARK-16167][SQL] RowEncoder should preserve array/map type nullability.
## What changes were proposed in this pull request?

Currently `RowEncoder` doesn't preserve the nullability of `ArrayType` or `MapType`. It always returns `containsNull = true` for `ArrayType` and `valueContainsNull = true` for `MapType`, and the nullability of the array/map field itself is also always `true`. This PR fixes the nullability of them.

## How was this patch tested?

Added tests to check that `RowEncoder` preserves array/map nullability.

Author: Takuya UESHIN <[email protected]>
Author: Takuya UESHIN <[email protected]>

Closes #13873 from ueshin/issues/SPARK-16167.
1 parent 4852b7d commit 873f3ad

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -123,7 +123,7 @@ object RowEncoder {
123123
inputObject :: Nil,
124124
returnNullable = false)
125125

126-
case t @ ArrayType(et, cn) =>
126+
case t @ ArrayType(et, containsNull) =>
127127
et match {
128128
case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
129129
StaticInvoke(
@@ -132,8 +132,16 @@ object RowEncoder {
132132
"toArrayData",
133133
inputObject :: Nil,
134134
returnNullable = false)
135+
135136
case _ => MapObjects(
136-
element => serializerFor(ValidateExternalType(element, et), et),
137+
element => {
138+
val value = serializerFor(ValidateExternalType(element, et), et)
139+
if (!containsNull) {
140+
AssertNotNull(value, Seq.empty)
141+
} else {
142+
value
143+
}
144+
},
137145
inputObject,
138146
ObjectType(classOf[Object]))
139147
}
@@ -155,10 +163,19 @@ object RowEncoder {
155163
ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
156164
val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))
157165

158-
NewInstance(
166+
val nonNullOutput = NewInstance(
159167
classOf[ArrayBasedMapData],
160168
convertedKeys :: convertedValues :: Nil,
161-
dataType = t)
169+
dataType = t,
170+
propagateNull = false)
171+
172+
if (inputObject.nullable) {
173+
If(IsNull(inputObject),
174+
Literal.create(null, inputType),
175+
nonNullOutput)
176+
} else {
177+
nonNullOutput
178+
}
162179

163180
case StructType(fields) =>
164181
val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -273,6 +273,39 @@ class RowEncoderSuite extends SparkFunSuite {
273273
assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
274274
}
275275

276+
for {
277+
elementType <- Seq(IntegerType, StringType)
278+
containsNull <- Seq(true, false)
279+
nullable <- Seq(true, false)
280+
} {
281+
test("RowEncoder should preserve array nullability: " +
282+
s"ArrayType($elementType, containsNull = $containsNull), nullable = $nullable") {
283+
val schema = new StructType().add("array", ArrayType(elementType, containsNull), nullable)
284+
val encoder = RowEncoder(schema).resolveAndBind()
285+
assert(encoder.serializer.length == 1)
286+
assert(encoder.serializer.head.dataType == ArrayType(elementType, containsNull))
287+
assert(encoder.serializer.head.nullable == nullable)
288+
}
289+
}
290+
291+
for {
292+
keyType <- Seq(IntegerType, StringType)
293+
valueType <- Seq(IntegerType, StringType)
294+
valueContainsNull <- Seq(true, false)
295+
nullable <- Seq(true, false)
296+
} {
297+
test("RowEncoder should preserve map nullability: " +
298+
s"MapType($keyType, $valueType, valueContainsNull = $valueContainsNull), " +
299+
s"nullable = $nullable") {
300+
val schema = new StructType().add(
301+
"map", MapType(keyType, valueType, valueContainsNull), nullable)
302+
val encoder = RowEncoder(schema).resolveAndBind()
303+
assert(encoder.serializer.length == 1)
304+
assert(encoder.serializer.head.dataType == MapType(keyType, valueType, valueContainsNull))
305+
assert(encoder.serializer.head.nullable == nullable)
306+
}
307+
}
308+
276309
private def encodeDecodeTest(schema: StructType): Unit = {
277310
test(s"encode/decode: ${schema.simpleString}") {
278311
val encoder = RowEncoder(schema).resolveAndBind()

0 commit comments

Comments
 (0)