From de1eff69ab21f9b03108d46dd65189ade657ae66 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 2 Jun 2016 09:56:27 -0700 Subject: [PATCH 1/2] make the semantics of null input object for encoder clear --- .../sql/catalyst/encoders/ExpressionEncoder.scala | 13 ++++++++++--- .../spark/sql/catalyst/encoders/RowEncoder.scala | 7 +++---- .../sql/catalyst/expressions/objects/objects.scala | 4 ++-- .../sql/catalyst/encoders/RowEncoderSuite.scala | 8 ++++++++ .../scala/org/apache/spark/sql/DatasetSuite.scala | 10 ++++++++++ 5 files changed, 33 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 2296946cd7c5..817af9d4db23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaRefle import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.types.{ObjectType, StructField, StructType} @@ -52,8 +52,15 @@ object ExpressionEncoder { val cls = mirror.runtimeClass(tpe) val flat = !ScalaReflection.definedByConstructorParams(tpe) - val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false) - val serializer = ScalaReflection.serializerFor[T](inputObject) + val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true) + val nullSafeInput = if (flat) { + inputObject + } else { + // For input object of non-flat type, we can't encode it to row if it's null, as Spark SQL + // doesn't allow row to be null, only its columns can be null. + AssertNotNull(inputObject, Seq("top level non-flat input object")) + } + val serializer = ScalaReflection.serializerFor[T](nullSafeInput) val deserializer = ScalaReflection.deserializerFor[T] val schema = ScalaReflection.schemaFor[T] match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 0de9166aa29a..2c2361459824 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -56,8 +56,8 @@ import org.apache.spark.unsafe.types.UTF8String object RowEncoder { def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] - val inputObject = BoundReference(0, ObjectType(cls), nullable = false) - val serializer = serializerFor(inputObject, schema) + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), schema) val deserializer = deserializerFor(schema) new ExpressionEncoder[Row]( schema, @@ -152,8 +152,7 @@ object RowEncoder { val fieldValue = serializerFor( GetExternalRowField( inputObject, index, field.name, externalDataTypeForInput(field.dataType)), - field.dataType - ) + field.dataType) val convertedField = if (field.nullable) { If( Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index c2e3ab82ff16..d4c71bffe86b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -519,7 +519,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) val code = s""" $values = new Object[${children.size}]; $childrenCode - final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField); + final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); """ ev.copy(code = code, isNull = "false") } @@ -675,7 +675,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) ${childGen.code} if (${childGen.isNull}) { - throw new RuntimeException(this.$errMsgField); + throw new RuntimeException($errMsgField); } """ ev.copy(code = code, isNull = "false", value = childGen.value) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 39fcc7225b37..b075e06127ee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -224,6 +224,14 @@ class RowEncoderSuite extends SparkFunSuite { assert(convertedBack.getSeq(2) == Seq(Seq(Seq(0L, null), null), null)) } + test("RowEncoder should throw RuntimeException if input row object is null") { + val schema = new StructType().add("int", IntegerType) + val encoder = RowEncoder(schema) + val e = intercept[RuntimeException](encoder.toRow(null)) + assert(e.getMessage.contains("Null value appeared in non-nullable field")) + assert(e.getMessage.contains("top level row object")) + } + private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { val encoder = RowEncoder(schema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index a3881ff92015..4d85eaf3cf84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -786,6 +786,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val result = joined.collect().toSet assert(result == Set(ClassData("a", 1) -> null, ClassData("b", 2) -> ClassData("x", 2))) } + + test("Dataset should support flat input object to be null") { + checkDataset(Seq("a", null).toDS(), "a", null) + } + + test("Dataset should throw RuntimeException if non-flat input object is null") { + val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS()) + assert(e.getMessage.contains("Null value appeared in non-nullable field")) + assert(e.getMessage.contains("top level non-flat input object")) + } } case class Generic[T](id: T, value: Double) From 6b2641fc5e66e99a893108016f3c1caa1aa9e9d6 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 3 Jun 2016 10:34:19 -0700 Subject: [PATCH 2/2] address comments --- .../apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 524e2bc66855..688082dcce53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -55,7 +55,7 @@ object ExpressionEncoder { inputObject } else { // For input object of non-flat type, we can't encode it to row if it's null, as Spark SQL - // doesn't allow row to be null, only its columns can be null. + // doesn't allow top-level row to be null, only its columns can be null. AssertNotNull(inputObject, Seq("top level non-flat input object")) } val serializer = ScalaReflection.serializerFor[T](nullSafeInput)