Commit aa07e56

nullability of array type element should not fail analysis of encoder

1 parent b1835d7, commit aa07e56
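
In short: AssertNotNull now records a walked type path (a Seq[String] describing where in the object graph the value sits) instead of a (parentType, fieldName, fieldType) triple, ScalaReflection's Seq[_] branch wraps element converters in AssertNotNull only when the element type is non-nullable, and ResolveUpCast casts to the nullable form of the target type, so array-element nullability no longer fails analysis. A minimal sketch of the behavior this enables, condensed from the new test in EncoderResolutionSuite below (same encoder API the test uses):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.catalyst.util.GenericArrayData
    import org.apache.spark.sql.types._

    // Resolve an encoder for Seq[Int] against an array column whose elements
    // may contain nulls: analysis succeeds, and a null element only fails at
    // runtime, inside the generated AssertNotNull check.
    val encoder = ExpressionEncoder[Seq[Int]]
    val attrs = 'a.array(IntegerType) :: Nil
    val bound = encoder.resolve(attrs, null).bind(attrs)

    bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))    // fine
    bound.fromRow(InternalRow(new GenericArrayData(Array(1, null)))) // RuntimeException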

File tree: 7 files changed (+68 / -109 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
Lines changed: 1 addition & 1 deletion

@@ -292,7 +292,7 @@ object JavaTypeInference {
       val setter = if (nullable) {
         constructor
       } else {
-        AssertNotNull(constructor, other.getName, fieldName, fieldType.toString)
+        AssertNotNull(constructor, Seq("currently no type path record in java"))
       }
       p.getWriteMethod.getName -> setter
     }.toMap

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
Lines changed: 19 additions & 10 deletions

@@ -249,6 +249,8 @@ object ScalaReflection extends ScalaReflection {
 
     case t if t <:< localTypeOf[Array[_]] =>
       val TypeRef(_, _, Seq(elementType)) = t
+
+      // TODO: add runtime null check for primitive array
       val primitiveMethod = elementType match {
         case t if t <:< definitions.IntTpe => Some("toIntArray")
         case t if t <:< definitions.LongTpe => Some("toLongArray")
@@ -276,22 +278,29 @@ object ScalaReflection extends ScalaReflection {
 
     case t if t <:< localTypeOf[Seq[_]] =>
       val TypeRef(_, _, Seq(elementType)) = t
+      val Schema(dataType, nullable) = schemaFor(elementType)
       val className = getClassNameFromType(elementType)
       val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
-      val arrayData =
-        Invoke(
-          MapObjects(
-            p => constructorFor(elementType, Some(p), newTypePath),
-            getPath,
-            schemaFor(elementType).dataType),
-          "array",
-          ObjectType(classOf[Array[Any]]))
+
+      val mapFunction: Expression => Expression = p => {
+        val converter = constructorFor(elementType, Some(p), newTypePath)
+        if (nullable) {
+          converter
+        } else {
+          AssertNotNull(converter, newTypePath)
+        }
+      }
+
+      val array = Invoke(
+        MapObjects(mapFunction, getPath, dataType),
+        "array",
+        ObjectType(classOf[Array[Any]]))
 
       StaticInvoke(
         scala.collection.mutable.WrappedArray.getClass,
         ObjectType(classOf[Seq[_]]),
         "make",
-        arrayData :: Nil)
+        array :: Nil)
 
     case t if t <:< localTypeOf[Map[_, _]] =>
       // TODO: add walked type path for map
@@ -343,7 +352,7 @@ object ScalaReflection extends ScalaReflection {
         newTypePath)
 
       if (!nullable) {
-        AssertNotNull(constructor, t.toString, fieldName, fieldType.toString)
+        AssertNotNull(constructor, newTypePath)
       } else {
        constructor
      }
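
The heart of the Seq[_] change is the decision inside mapFunction: leave the element converter alone when the element type is nullable, wrap it in AssertNotNull (carrying the walked type path for the error message) when it is not. A standalone sketch of that rule, with a plain Scala function standing in for Catalyst's Expression (names here are illustrative, not Spark's):

    // Converter plays the role of an element-conversion Expression.
    type Converter = Any => Any

    def withNullCheck(
        convert: Converter,
        nullable: Boolean,
        walkedTypePath: Seq[String]): Converter = {
      if (nullable) {
        convert // null elements are legal, pass them through
      } else {
        in =>
          if (in == null) {
            // mirrors the message AssertNotNull generates (see objects.scala below)
            throw new RuntimeException("Null value appeared in non-nullable field:" +
              walkedTypePath.mkString("\n", "\n", "\n"))
          } else {
            convert(in)
          }
      }
    }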

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
Lines changed: 1 addition & 1 deletion

@@ -1346,7 +1346,7 @@ object ResolveUpCast extends Rule[LogicalPlan] {
           fail(child, DateType, walkedTypePath)
         case (StringType, to: NumericType) =>
           fail(child, to, walkedTypePath)
-        case _ => Cast(child, dataType)
+        case _ => Cast(child, dataType.asNullable)
       }
     }
   }
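
Casting to dataType.asNullable is what lets the up-cast succeed when the two sides differ only in nullability, such as an encoder schema with non-null array elements against an input column whose elements may be null. asNullable is private[spark], so here is a rough re-implementation of its effect on the types involved (an approximation for illustration, not Spark's source):

    import org.apache.spark.sql.types._

    def makeNullable(dt: DataType): DataType = dt match {
      case ArrayType(elem, _) =>
        ArrayType(makeNullable(elem), containsNull = true)
      case MapType(k, v, _) =>
        MapType(makeNullable(k), makeNullable(v), valueContainsNull = true)
      case StructType(fields) =>
        StructType(fields.map(f => f.copy(dataType = makeNullable(f.dataType), nullable = true)))
      case atomic => atomic
    }

    makeNullable(ArrayType(IntegerType, containsNull = false))
    // => ArrayType(IntegerType, containsNull = true)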

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
Lines changed: 11 additions & 9 deletions

@@ -365,7 +365,7 @@ object MapObjects {
  * to handle collection elements.
  * @param inputData An expression that when evaluted returns a collection object.
  */
-case class MapObjects(
+case class MapObjects private(
     loopVar: LambdaVariable,
     lambdaFunction: Expression,
     inputData: Expression) extends Expression {
@@ -637,8 +637,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
  * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all
  * non-null `s`, `s.i` can't be null.
  */
-case class AssertNotNull(
-    child: Expression, parentType: String, fieldName: String, fieldType: String)
+case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
   extends UnaryExpression {
 
   override def dataType: DataType = child.dataType
@@ -651,19 +650,22 @@ case class AssertNotNull(
   override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
     val childGen = child.gen(ctx)
 
+    val errMsg = "Null value appeared in non-nullable field:" +
+      walkedTypePath.mkString("\n", "\n", "\n") +
+      "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
+      "please try to use scala.Option[_] or other nullable types " +
+      "(e.g. java.lang.Integer instead of int/scala.Int)."
+    val idx = ctx.references.length
+    ctx.references += errMsg
+
     ev.isNull = "false"
     ev.value = childGen.value
 
     s"""
       ${childGen.code}
 
       if (${childGen.isNull}) {
-        throw new RuntimeException(
-          "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
-          "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
-          "please try to use scala.Option[_] or other nullable types " +
-          "(e.g. java.lang.Integer instead of int/scala.Int)."
-        );
+        throw new RuntimeException((String) references[$idx]);
       }
    """
  }
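
Because the message is now assembled from the walked type path at code-generation time, it is stored in the codegen references array and read back as (String) references[$idx] in the generated Java, rather than being spliced into the generated source as string literals. What the assembled message looks like for a hypothetical walked path (the two path entries below are made-up examples in the style ScalaReflection produces):

    val walkedTypePath = Seq(
      """- array element class: "scala.Int"""",
      """- root class: "scala.collection.Seq"""") // hypothetical path

    val errMsg = "Null value appeared in non-nullable field:" +
      walkedTypePath.mkString("\n", "\n", "\n") +
      "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
      "please try to use scala.Option[_] or other nullable types " +
      "(e.g. java.lang.Integer instead of int/scala.Int)."

    println(errMsg)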

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
Lines changed: 30 additions & 77 deletions

@@ -21,9 +21,11 @@ import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 case class StringLongClass(a: String, b: Long)
 
@@ -32,94 +34,49 @@ case class StringIntClass(a: String, b: Int)
 case class ComplexClass(a: Long, b: StringLongClass)
 
 class EncoderResolutionSuite extends PlanTest {
+  private val str = UTF8String.fromString("hello")
+
   test("real type doesn't match encoder schema but they are compatible: product") {
     val encoder = ExpressionEncoder[StringLongClass]
-    val cls = classOf[StringLongClass]
-
 
-    {
-      val attrs = Seq('a.string, 'b.int)
-      val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
-      val expected: Expression = NewInstance(
-        cls,
-        Seq(
-          toExternalString('a.string),
-          AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
-        ),
-        ObjectType(cls),
-        propagateNull = false)
-      compareExpressions(fromRowExpr, expected)
-    }
+    // int type can be up cast to long type
+    val attrs1 = Seq('a.string, 'b.int)
+    encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1))
 
-    {
-      val attrs = Seq('a.int, 'b.long)
-      val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
-      val expected = NewInstance(
-        cls,
-        Seq(
-          toExternalString('a.int.cast(StringType)),
-          AssertNotNull('b.long, cls.getName, "b", "Long")
-        ),
-        ObjectType(cls),
-        propagateNull = false)
-      compareExpressions(fromRowExpr, expected)
-    }
+    // int type can be up cast to string type
+    val attrs2 = Seq('a.int, 'b.long)
+    encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L))
   }
 
   test("real type doesn't match encoder schema but they are compatible: nested product") {
     val encoder = ExpressionEncoder[ComplexClass]
-    val innerCls = classOf[StringLongClass]
-    val cls = classOf[ComplexClass]
-
     val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
-    val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
-    val expected: Expression = NewInstance(
-      cls,
-      Seq(
-        AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"),
-        If(
-          'b.struct('a.int, 'b.long).isNull,
-          Literal.create(null, ObjectType(innerCls)),
-          NewInstance(
-            innerCls,
-            Seq(
-              toExternalString(
-                GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
-              AssertNotNull(
-                GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
-                innerCls.getName, "b", "Long")),
-            ObjectType(innerCls),
-            propagateNull = false)
-        )),
-      ObjectType(cls),
-      propagateNull = false)
-    compareExpressions(fromRowExpr, expected)
+    encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
   }
 
   test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
     val encoder = ExpressionEncoder.tuple(
       ExpressionEncoder[StringLongClass],
       ExpressionEncoder[Long])
-    val cls = classOf[StringLongClass]
-
     val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
-    val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
-    val expected: Expression = NewInstance(
-      classOf[Tuple2[_, _]],
-      Seq(
-        NewInstance(
-          cls,
-          Seq(
-            toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
-            AssertNotNull(
-              GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
-              cls.getName, "b", "Long")),
-          ObjectType(cls),
-          propagateNull = false),
-        'b.int.cast(LongType)),
-      ObjectType(classOf[Tuple2[_, _]]),
-      propagateNull = false)
-    compareExpressions(fromRowExpr, expected)
+    encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
+  }
+
+  test("nullability of array type element should not fail analysis") {
+    val encoder = ExpressionEncoder[Seq[Int]]
+    val attrs = 'a.array(IntegerType) :: Nil
+
+    // It should pass analysis
+    val bound = encoder.resolve(attrs, null).bind(attrs)
+
+    // If no null values appear, it should works fine
+    bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
+
+    // If there is null value, it should throw runtime exception
+    val e = intercept[RuntimeException] {
+      bound.fromRow(InternalRow(new GenericArrayData(Array(1, null))))
+    }
+    assert(e.getMessage.contains("Null value appeared in non-nullable field"))
  }
 
   test("the real number of fields doesn't match encoder schema: tuple encoder") {
@@ -166,10 +123,6 @@ class EncoderResolutionSuite extends PlanTest {
     }
   }
 
-  private def toExternalString(e: Expression): Expression = {
-    Invoke(e, "toString", ObjectType(classOf[String]), Nil)
-  }
-
   test("throw exception if real type is not compatible with encoder schema") {
     val msg1 = intercept[AnalysisException] {
       ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
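
The attrs line is what makes the new test a nullability test at all: 'a.array(IntegerType) describes an array column whose elements may be null (assuming the DSL helper builds ArrayType with its one-argument constructor, which defaults containsNull to true), while the encoder schema for Seq[Int] marks elements non-null. A minimal sketch of the mismatch that used to fail analysis:

    import org.apache.spark.sql.types._

    val inputType   = ArrayType(IntegerType)                        // containsNull = true
    val encoderType = ArrayType(IntegerType, containsNull = false)  // primitive elements

    // Before this commit, up-casting between these failed analysis purely
    // because of the containsNull mismatch; now the cast targets the nullable
    // form and an actual null element is caught at runtime by AssertNotNull.
    assert(inputType != encoderType)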

sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
Lines changed: 1 addition & 3 deletions

@@ -850,9 +850,7 @@ public void testRuntimeNullabilityCheck() {
     }
 
     nullabilityCheck.expect(RuntimeException.class);
-    nullabilityCheck.expectMessage(
-      "Null value appeared in non-nullable field " +
-        "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
+    nullabilityCheck.expectMessage("Null value appeared in non-nullable field");
 
     {
       Row row = new GenericRow(new Object[] {

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Lines changed: 5 additions & 8 deletions

@@ -45,13 +45,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       1, 1, 1)
   }
 
-
   test("SPARK-12404: Datatype Helper Serializablity") {
     val ds = sparkContext.parallelize((
-        new Timestamp(0),
-        new Date(0),
-        java.math.BigDecimal.valueOf(1),
-        scala.math.BigDecimal(1)) :: Nil).toDS()
+      new Timestamp(0),
+      new Date(0),
+      java.math.BigDecimal.valueOf(1),
+      scala.math.BigDecimal(1)) :: Nil).toDS()
 
     ds.collect()
   }
@@ -553,9 +552,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       buildDataset(Row(Row("hello", null))).collect()
     }.getMessage
 
-    assert(message.contains(
-      "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
-    ))
+    assert(message.contains("Null value appeared in non-nullable field"))
   }
 
   test("SPARK-12478: top level null field") {
