Skip to content

Commit d202ad2

Browse files
viiryamarmbrus
authored andcommitted
[SPARK-12439][SQL] Fix toCatalystArray and MapObjects
JIRA: https://issues.apache.org/jira/browse/SPARK-12439 In toCatalystArray, we should look at the data type returned by dataTypeFor instead of silentSchemaFor, to determine if the element is native type. An obvious problem is when the element is Option[Int] class, catalsilentSchemaFor will return Int, then we will wrongly recognize the element is native type. There is another problem when using Option as array element. When we encode data like Seq(Some(1), Some(2), None) with encoder, we will use MapObjects to construct an array for it later. But in MapObjects, we don't check if the return value of lambdaFunction is null or not. That causes a bug that the decoded data for Seq(Some(1), Some(2), None) would be Seq(1, 2, -1), instead of Seq(1, 2, null). Author: Liang-Chi Hsieh <[email protected]> Closes #10391 from viirya/fix-catalystarray.
1 parent 8ce645d commit d202ad2

File tree

4 files changed

+14
-6
lines changed

4 files changed

+14
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ object ScalaReflection extends ScalaReflection {
405405
def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
406406
val externalDataType = dataTypeFor(elementType)
407407
val Schema(catalystType, nullable) = silentSchemaFor(elementType)
408-
if (isNativeType(catalystType)) {
408+
if (isNativeType(externalDataType)) {
409409
NewInstance(
410410
classOf[GenericArrayData],
411411
input :: Nil,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ object RowEncoder {
3535
def apply(schema: StructType): ExpressionEncoder[Row] = {
3636
val cls = classOf[Row]
3737
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
38-
val extractExpressions = extractorsFor(inputObject, schema)
38+
// We use an If expression to wrap extractorsFor result of StructType
39+
val extractExpressions = extractorsFor(inputObject, schema).asInstanceOf[If].falseValue
3940
val constructExpression = constructorFor(schema)
4041
new ExpressionEncoder[Row](
4142
schema,
@@ -129,7 +130,9 @@ object RowEncoder {
129130
Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil),
130131
f.dataType))
131132
}
132-
CreateStruct(convertedFields)
133+
If(IsNull(inputObject),
134+
Literal.create(null, inputType),
135+
CreateStruct(convertedFields))
133136
}
134137

135138
private def externalDataTypeFor(dt: DataType): DataType = dt match {
@@ -220,6 +223,8 @@ object RowEncoder {
220223
Literal.create(null, externalDataTypeFor(f.dataType)),
221224
constructorFor(GetStructField(input, i)))
222225
}
223-
CreateExternalRow(convertedFields)
226+
If(IsNull(input),
227+
Literal.create(null, externalDataTypeFor(input.dataType)),
228+
CreateExternalRow(convertedFields))
224229
}
225230
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,10 +456,10 @@ case class MapObjects(
456456
($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
457457
$loopNullCheck
458458

459-
if (${loopVar.isNull}) {
459+
${genFunction.code}
460+
if (${genFunction.isNull}) {
460461
$convertedArray[$loopIndex] = null;
461462
} else {
462-
${genFunction.code}
463463
$convertedArray[$loopIndex] = ${genFunction.value};
464464
}
465465

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ class ExpressionEncoderSuite extends SparkFunSuite {
160160

161161
productTest(OptionalData(None, None, None, None, None, None, None, None))
162162

163+
encodeDecodeTest(Seq(Some(1), None), "Option in array")
164+
encodeDecodeTest(Map(1 -> Some(10L), 2 -> Some(20L), 3 -> None), "Option in map")
165+
163166
productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
164167

165168
productTest(BoxedData(null, null, null, null, null, null, null))

0 commit comments

Comments
 (0)