Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,14 @@ object JavaTypeInference {

case _ if mapType.isAssignableFrom(typeToken) =>
val (keyType, valueType) = mapKeyValueType(typeToken)

ExternalMapToCatalyst(
inputObject,
ObjectType(keyType.getRawType),
serializerFor(_, keyType),
ObjectType(valueType.getRawType),
serializerFor(_, valueType)
serializerFor(_, valueType),
valueNullable = true
)

case other =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ object ScalaReflection extends ScalaReflection {
dataTypeFor(keyType),
serializerFor(_, keyType, keyPath),
dataTypeFor(valueType),
serializerFor(_, valueType, valuePath))
serializerFor(_, valueType, valuePath),
valueNullable = !valueType.typeSymbol.asClass.isPrimitive)

case t if t <:< localTypeOf[String] =>
StaticInvoke(
Expand Down Expand Up @@ -590,7 +591,9 @@ object ScalaReflection extends ScalaReflection {
"cannot be used as field name\n" + walkedTypePath.mkString("\n"))
}

val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
val fieldValue = Invoke(
AssertNotNull(inputObject, walkedTypePath), fieldName, dataTypeFor(fieldType),
returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
val clsName = getClassNameFromType(fieldType)
val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ object ExpressionEncoder {
val cls = mirror.runtimeClass(tpe)
val flat = !ScalaReflection.definedByConstructorParams(tpe)

val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = !cls.isPrimitive)
Copy link
Contributor

@cloud-fan cloud-fan Nov 5, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it work for primitive types in complex types? e.g. Seq(1, 2, 3), should have data type ArrayType(IntegerType, containsNull = false)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it works. I have just added a new test for this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ping @cloud-fan

val nullSafeInput = if (flat) {
inputObject
} else {
Expand All @@ -61,10 +61,7 @@ object ExpressionEncoder {
val serializer = ScalaReflection.serializerFor[T](nullSafeInput)
val deserializer = ScalaReflection.deserializerFor[T]

val schema = ScalaReflection.schemaFor[T] match {
case ScalaReflection.Schema(s: StructType, _) => s
case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable)
}
val schema = serializer.dataType

new ExpressionEncoder[T](
schema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
ctx.addMutableState("boolean", classChildVarIsNull, "")

val classChildVar =
LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)
LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType, child.nullable)

val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
s"${classChildVar.isNull} = ${childGen.isNull};"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,18 @@ case class StaticInvoke(
* @param arguments An optional list of expressions, whos evaluation will be passed to the function.
* @param propagateNull When true, and any of the arguments is null, null will be returned instead
* of calling the function.
* @param returnNullable When false, indicating the invoked method will always return
* non-null value.
*/
case class Invoke(
targetObject: Expression,
functionName: String,
dataType: DataType,
arguments: Seq[Expression] = Nil,
propagateNull: Boolean = true) extends InvokeLike {
propagateNull: Boolean = true,
returnNullable : Boolean = true) extends InvokeLike {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a @param document for returnNullable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, done


override def nullable: Boolean = true
override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
override def children: Seq[Expression] = targetObject +: arguments

override def eval(input: InternalRow): Any =
Expand Down Expand Up @@ -405,13 +408,15 @@ case class WrapOption(child: Expression, optType: DataType)
* A place holder for the loop variable used in [[MapObjects]]. This should never be constructed
* manually, but will instead be passed into the provided lambda function.
*/
case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression
case class LambdaVariable(
value: String,
isNull: String,
dataType: DataType,
nullable: Boolean = true) extends LeafExpression
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: code style should be

case class xxx(
  param1: xxx
  param2: xxx) extends ...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

with Unevaluable with NonSQLExpression {

override def nullable: Boolean = true

override def genCode(ctx: CodegenContext): ExprCode = {
ExprCode(code = "", value = value, isNull = isNull)
ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
}
}

Expand Down Expand Up @@ -592,7 +597,8 @@ object ExternalMapToCatalyst {
keyType: DataType,
keyConverter: Expression => Expression,
valueType: DataType,
valueConverter: Expression => Expression): ExternalMapToCatalyst = {
valueConverter: Expression => Expression,
valueNullable: Boolean): ExternalMapToCatalyst = {
val id = curId.getAndIncrement()
val keyName = "ExternalMapToCatalyst_key" + id
val valueName = "ExternalMapToCatalyst_value" + id
Expand All @@ -601,11 +607,11 @@ object ExternalMapToCatalyst {
ExternalMapToCatalyst(
keyName,
keyType,
keyConverter(LambdaVariable(keyName, "false", keyType)),
keyConverter(LambdaVariable(keyName, "false", keyType, false)),
valueName,
valueIsNull,
valueType,
valueConverter(LambdaVariable(valueName, valueIsNull, valueType)),
valueConverter(LambdaVariable(valueName, valueIsNull, valueType, valueNullable)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we update the LambdaVariable for key too? It's always non-nullable

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, done

inputMap
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
import org.apache.spark.sql.catalyst.analysis.AnalysisTest
import org.apache.spark.sql.catalyst.dsl.plans._
Expand Down Expand Up @@ -300,6 +300,11 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
encodeDecodeTest(
ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class")

encodeDecodeTest(Option(31), "option of int")
encodeDecodeTest(Option.empty[Int], "empty option of int")
encodeDecodeTest(Option("abc"), "option of string")
encodeDecodeTest(Option.empty[String], "empty option of string")

productTest(("UDT", new ExamplePoint(0.1, 0.2)))

test("nullable of encoder schema") {
Expand Down Expand Up @@ -338,6 +343,18 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
}
}

test("nullable of encoder serializer") {
def checkNullable[T: Encoder](nullable: Boolean): Unit = {
assert(encoderFor[T].serializer.forall(_.nullable === nullable))
}

// test for flat encoders
checkNullable[Int](false)
checkNullable[Option[Int]](true)
checkNullable[java.lang.Integer](true)
checkNullable[String](true)
}

test("null check for map key") {
val encoder = ExpressionEncoder[Map[String, Int]]()
val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2))))
Expand Down
52 changes: 51 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.types._

case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
case class TestDataPoint2(x: Int, s: String)

class DatasetSuite extends QueryTest with SharedSQLContext {
import testImplicits._
Expand Down Expand Up @@ -969,6 +972,53 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(dataset.collect() sameElements Array(resultValue, resultValue))
}

test("SPARK-18284: Serializer should have correct nullable value") {
val df1 = Seq(1, 2, 3, 4).toDF
assert(df1.schema(0).nullable == false)
val df2 = Seq(Integer.valueOf(1), Integer.valueOf(2)).toDF
assert(df2.schema(0).nullable == true)

val df3 = Seq(Seq(1, 2), Seq(3, 4)).toDF
assert(df3.schema(0).nullable == true)
assert(df3.schema(0).dataType.asInstanceOf[ArrayType].containsNull == false)
val df4 = Seq(Seq("a", "b"), Seq("c", "d")).toDF
assert(df4.schema(0).nullable == true)
assert(df4.schema(0).dataType.asInstanceOf[ArrayType].containsNull == true)

val df5 = Seq((0, 1.0), (2, 2.0)).toDF("id", "v")
assert(df5.schema(0).nullable == false)
assert(df5.schema(1).nullable == false)
val df6 = Seq((0, 1.0, "a"), (2, 2.0, "b")).toDF("id", "v1", "v2")
assert(df6.schema(0).nullable == false)
assert(df6.schema(1).nullable == false)
assert(df6.schema(2).nullable == true)

val df7 = (Tuple1(Array(1, 2, 3)) :: Nil).toDF("a")
assert(df7.schema(0).nullable == true)
assert(df7.schema(0).dataType.asInstanceOf[ArrayType].containsNull == false)

val df8 = (Tuple1(Array((null: Integer), (null: Integer))) :: Nil).toDF("a")
assert(df8.schema(0).nullable == true)
assert(df8.schema(0).dataType.asInstanceOf[ArrayType].containsNull == true)

val df9 = (Tuple1(Map(2 -> 3)) :: Nil).toDF("m")
assert(df9.schema(0).nullable == true)
assert(df9.schema(0).dataType.asInstanceOf[MapType].valueContainsNull == false)

val df10 = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("m")
assert(df10.schema(0).nullable == true)
assert(df10.schema(0).dataType.asInstanceOf[MapType].valueContainsNull == true)

val df11 = Seq(TestDataPoint(1, 2.2, "a", null),
TestDataPoint(3, 4.4, "null", (TestDataPoint2(33, "b")))).toDF
assert(df11.schema(0).nullable == false)
assert(df11.schema(1).nullable == false)
assert(df11.schema(2).nullable == true)
assert(df11.schema(3).nullable == true)
assert(df11.schema(3).dataType.asInstanceOf[StructType].fields(0).nullable == false)
assert(df11.schema(3).dataType.asInstanceOf[StructType].fields(1).nullable == true)
}

Seq(true, false).foreach { eager =>
def testCheckpointing(testName: String)(f: => Unit): Unit = {
test(s"Dataset.checkpoint() - $testName (eager = $eager)") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class FileStreamSinkSuite extends StreamTest {

val outputDf = spark.read.parquet(outputDir)
val expectedSchema = new StructType()
.add(StructField("value", IntegerType))
.add(StructField("value", IntegerType, nullable = false))
.add(StructField("id", IntegerType))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, do you know why id is not nullable == false?
Looks both value and id are nullable == false.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, this code makes 'id' nullable. The column specified by partitionBy() will have nullable=true.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks.

assert(outputDf.schema === expectedSchema)

Expand Down