Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaRefle
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
Expand All @@ -50,8 +50,15 @@ object ExpressionEncoder {
val cls = mirror.runtimeClass(tpe)
val flat = !ScalaReflection.definedByConstructorParams(tpe)

val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false)
val serializer = ScalaReflection.serializerFor[T](inputObject)
val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
val nullSafeInput = if (flat) {
inputObject
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to also disallow null input object for flat type, but failed as some other tests already depend on this feature. e.g. ParquetIOSuite.null and non-null strings

} else {
// For input object of non-flat type, we can't encode it to row if it's null, as Spark SQL
// doesn't allow top-level row to be null, only its columns can be null.
AssertNotNull(inputObject, Seq("top level non-flat input object"))
}
val serializer = ScalaReflection.serializerFor[T](nullSafeInput)
val deserializer = ScalaReflection.deserializerFor[T]

val schema = ScalaReflection.schemaFor[T] match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ import org.apache.spark.unsafe.types.UTF8String
object RowEncoder {
def apply(schema: StructType): ExpressionEncoder[Row] = {
val cls = classOf[Row]
val inputObject = BoundReference(0, ObjectType(cls), nullable = false)
val serializer = serializerFor(inputObject, schema)
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), schema)
val deserializer = deserializerFor(schema)
new ExpressionEncoder[Row](
schema,
Expand Down Expand Up @@ -153,8 +153,7 @@ object RowEncoder {
val fieldValue = serializerFor(
GetExternalRowField(
inputObject, index, field.name, externalDataTypeForInput(field.dataType)),
field.dataType
)
field.dataType)
val convertedField = if (field.nullable) {
If(
Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
val code = s"""
$values = new Object[${children.size}];
$childrenCode
final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);
final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Preserving this is safer (just in case name collision between class fields and local variables).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dropping this is safer, as the code maybe put in an inner class.
We generate names by a unique id and it's impossible to get name collision.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, thanks!

"""
ev.copy(code = code, isNull = "false")
}
Expand Down Expand Up @@ -675,7 +675,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
${childGen.code}

if (${childGen.isNull}) {
throw new RuntimeException(this.$errMsgField);
throw new RuntimeException($errMsgField);
Copy link
Contributor

@liancheng liancheng Jun 3, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

}
"""
ev.copy(code = code, isNull = "false", value = childGen.value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,14 @@ class RowEncoderSuite extends SparkFunSuite {
assert(convertedBack.getSeq(2) == Seq(Seq(Seq(0L, null), null), null))
}

test("RowEncoder should throw RuntimeException if input row object is null") {
val schema = new StructType().add("int", IntegerType)
val encoder = RowEncoder(schema)
val e = intercept[RuntimeException](encoder.toRow(null))
assert(e.getMessage.contains("Null value appeared in non-nullable field"))
assert(e.getMessage.contains("top level row object"))
}

private def encodeDecodeTest(schema: StructType): Unit = {
test(s"encode/decode: ${schema.simpleString}") {
val encoder = RowEncoder(schema).resolveAndBind()
Expand Down
10 changes: 10 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(e.getMessage.contains(
"`abstract` is a reserved keyword and cannot be used as field name"))
}

test("Dataset should support flat input object to be null") {
checkDataset(Seq("a", null).toDS(), "a", null)
}

test("Dataset should throw RuntimeException if non-flat input object is null") {
val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS())
assert(e.getMessage.contains("Null value appeared in non-nullable field"))
assert(e.getMessage.contains("top level non-flat input object"))
}
}

case class Generic[T](id: T, value: Double)
Expand Down