From aa07e56408e95559cf04cc67a9cc9246c5e4f43c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 2 Feb 2016 15:31:54 -0800 Subject: [PATCH] nullability of array type element should not fail analysis of encoder --- .../sql/catalyst/JavaTypeInference.scala | 2 +- .../spark/sql/catalyst/ScalaReflection.scala | 29 +++-- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/expressions/objects.scala | 20 ++-- .../encoders/EncoderResolutionSuite.scala | 107 +++++------------- .../apache/spark/sql/JavaDatasetSuite.java | 4 +- .../org/apache/spark/sql/DatasetSuite.scala | 13 +-- 7 files changed, 68 insertions(+), 109 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 3c3717d5043aa..59ee41d02f198 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -292,7 +292,7 @@ object JavaTypeInference { val setter = if (nullable) { constructor } else { - AssertNotNull(constructor, other.getName, fieldName, fieldType.toString) + AssertNotNull(constructor, Seq("currently no type path record in java")) } p.getWriteMethod.getName -> setter }.toMap diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index e5811efb436a6..02cb2d9a2b118 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -249,6 +249,8 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t + + // TODO: add runtime null check for primitive array val primitiveMethod = elementType match { case t if t <:< definitions.IntTpe => Some("toIntArray") case t if t <:< definitions.LongTpe => Some("toLongArray") @@ -276,22 +278,29 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - val arrayData = - Invoke( - MapObjects( - p => constructorFor(elementType, Some(p), newTypePath), - getPath, - schemaFor(elementType).dataType), - "array", - ObjectType(classOf[Array[Any]])) + + val mapFunction: Expression => Expression = p => { + val converter = constructorFor(elementType, Some(p), newTypePath) + if (nullable) { + converter + } else { + AssertNotNull(converter, newTypePath) + } + } + + val array = Invoke( + MapObjects(mapFunction, getPath, dataType), + "array", + ObjectType(classOf[Array[Any]])) StaticInvoke( scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", - arrayData :: Nil) + array :: Nil) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map @@ -343,7 +352,7 @@ object ScalaReflection extends ScalaReflection { newTypePath) if (!nullable) { - AssertNotNull(constructor, t.toString, fieldName, fieldType.toString) + AssertNotNull(constructor, newTypePath) } else { constructor } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a983dc1cdfebe..375896cd5a7d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1346,7 +1346,7 @@ object ResolveUpCast extends Rule[LogicalPlan] { fail(child, DateType, walkedTypePath) case (StringType, to: NumericType) => fail(child, to, walkedTypePath) - case _ => Cast(child, dataType) + case _ => Cast(child, dataType.asNullable) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 79fe0033b71ab..fef6825b2db5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -365,7 +365,7 @@ object MapObjects { * to handle collection elements. * @param inputData An expression that when evaluted returns a collection object. */ -case class MapObjects( +case class MapObjects private( loopVar: LambdaVariable, lambdaFunction: Expression, inputData: Expression) extends Expression { @@ -637,8 +637,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all * non-null `s`, `s.i` can't be null. */ -case class AssertNotNull( - child: Expression, parentType: String, fieldName: String, fieldType: String) +case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) extends UnaryExpression { override def dataType: DataType = child.dataType @@ -651,6 +650,14 @@ case class AssertNotNull( override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val childGen = child.gen(ctx) + val errMsg = "Null value appeared in non-nullable field:" + + walkedTypePath.mkString("\n", "\n", "\n") + + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + + "please try to use scala.Option[_] or other nullable types " + + "(e.g. java.lang.Integer instead of int/scala.Int)." + val idx = ctx.references.length + ctx.references += errMsg + ev.isNull = "false" ev.value = childGen.value @@ -658,12 +665,7 @@ case class AssertNotNull( ${childGen.code} if (${childGen.isNull}) { - throw new RuntimeException( - "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " + - "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + - "please try to use scala.Option[_] or other nullable types " + - "(e.g. java.lang.Integer instead of int/scala.Int)." - ); + throw new RuntimeException((String) references[$idx]); } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index bc36a55ae0ea2..1d7a708cdc2ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -21,9 +21,11 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String case class StringLongClass(a: String, b: Long) @@ -32,94 +34,49 @@ case class StringIntClass(a: String, b: Int) case class ComplexClass(a: Long, b: StringLongClass) class EncoderResolutionSuite extends PlanTest { + private val str = UTF8String.fromString("hello") + test("real type doesn't match encoder schema but they are compatible: product") { val encoder = ExpressionEncoder[StringLongClass] - val cls = classOf[StringLongClass] - - { - val attrs = Seq('a.string, 'b.int) - val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression - val expected: Expression = NewInstance( - cls, - Seq( - toExternalString('a.string), - AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long") - ), - ObjectType(cls), - propagateNull = false) - compareExpressions(fromRowExpr, expected) - } + // int type can be up cast to long type + val attrs1 = Seq('a.string, 'b.int) + encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1)) - { - val attrs = Seq('a.int, 'b.long) - val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression - val expected = NewInstance( - cls, - Seq( - toExternalString('a.int.cast(StringType)), - AssertNotNull('b.long, cls.getName, "b", "Long") - ), - ObjectType(cls), - propagateNull = false) - compareExpressions(fromRowExpr, expected) - } + // int type can be up cast to string type + val attrs2 = Seq('a.int, 'b.long) + encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L)) } test("real type doesn't match encoder schema but they are compatible: nested product") { val encoder = ExpressionEncoder[ComplexClass] - val innerCls = classOf[StringLongClass] - val cls = classOf[ComplexClass] - val attrs = Seq('a.int, 'b.struct('a.int, 'b.long)) - val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression - val expected: Expression = NewInstance( - cls, - Seq( - AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"), - If( - 'b.struct('a.int, 'b.long).isNull, - Literal.create(null, ObjectType(innerCls)), - NewInstance( - innerCls, - Seq( - toExternalString( - GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)), - AssertNotNull( - GetStructField('b.struct('a.int, 'b.long), 1, Some("b")), - innerCls.getName, "b", "Long")), - ObjectType(innerCls), - propagateNull = false) - )), - ObjectType(cls), - propagateNull = false) - compareExpressions(fromRowExpr, expected) + encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L))) } test("real type doesn't match encoder schema but they are compatible: tupled encoder") { val encoder = ExpressionEncoder.tuple( ExpressionEncoder[StringLongClass], ExpressionEncoder[Long]) - val cls = classOf[StringLongClass] - val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int) - val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression - val expected: Expression = NewInstance( - classOf[Tuple2[_, _]], - Seq( - NewInstance( - cls, - Seq( - toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))), - AssertNotNull( - GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType), - cls.getName, "b", "Long")), - ObjectType(cls), - propagateNull = false), - 'b.int.cast(LongType)), - ObjectType(classOf[Tuple2[_, _]]), - propagateNull = false) - compareExpressions(fromRowExpr, expected) + encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) + } + + test("nullability of array type element should not fail analysis") { + val encoder = ExpressionEncoder[Seq[Int]] + val attrs = 'a.array(IntegerType) :: Nil + + // It should pass analysis + val bound = encoder.resolve(attrs, null).bind(attrs) + + // If no null values appear, it should works fine + bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2)))) + + // If there is null value, it should throw runtime exception + val e = intercept[RuntimeException] { + bound.fromRow(InternalRow(new GenericArrayData(Array(1, null)))) + } + assert(e.getMessage.contains("Null value appeared in non-nullable field")) } test("the real number of fields doesn't match encoder schema: tuple encoder") { @@ -166,10 +123,6 @@ class EncoderResolutionSuite extends PlanTest { } } - private def toExternalString(e: Expression): Expression = { - Invoke(e, "toString", ObjectType(classOf[String]), Nil) - } - test("throw exception if real type is not compatible with encoder schema") { val msg1 = intercept[AnalysisException] { ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index a6fb62c17d59b..1181244c8a4ed 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -850,9 +850,7 @@ public void testRuntimeNullabilityCheck() { } nullabilityCheck.expect(RuntimeException.class); - nullabilityCheck.expectMessage( - "Null value appeared in non-nullable field " + - "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int."); + nullabilityCheck.expectMessage("Null value appeared in non-nullable field"); { Row row = new GenericRow(new Object[] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b69bb21db532b..cbebbc187b26f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -45,13 +45,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1, 1, 1) } - test("SPARK-12404: Datatype Helper Serializablity") { val ds = sparkContext.parallelize(( - new Timestamp(0), - new Date(0), - java.math.BigDecimal.valueOf(1), - scala.math.BigDecimal(1)) :: Nil).toDS() + new Timestamp(0), + new Date(0), + java.math.BigDecimal.valueOf(1), + scala.math.BigDecimal(1)) :: Nil).toDS() ds.collect() } @@ -553,9 +552,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { buildDataset(Row(Row("hello", null))).collect() }.getMessage - assert(message.contains( - "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int." - )) + assert(message.contains("Null value appeared in non-nullable field")) } test("SPARK-12478: top level null field") {