From 886beb0c5de8d2a971937200bd57f1f21cee253d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 29 Nov 2016 13:20:22 +0900 Subject: [PATCH 1/6] Add `returnNullable` to `StaticInvoke` and modify it to handle properly. --- .../expressions/objects/objects.scala | 41 +++++++++++++++---- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index e517ec18eb54..6bf8542c0064 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -116,17 +116,19 @@ trait InvokeLike extends Expression with NonSQLExpression { * @param arguments An optional list of expressions to pass as arguments to the function. * @param propagateNull When true, and any of the arguments is null, null will be returned instead * of calling the function. + * @param returnNullable When false, return value must be non-null. */ case class StaticInvoke( staticObject: Class[_], dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends InvokeLike { + propagateNull: Boolean = true, + returnNullable: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") - override def nullable: Boolean = true + override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments override def eval(input: InternalRow): Any = @@ -139,19 +141,40 @@ case class StaticInvoke( val callFunc = s"$objectName.$functionName($argString)" - // If the function can return null, we do an extra check to make sure our null bit is still set - // correctly. 
- val postNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" + val prepareIsNull = if (nullable) { + s"boolean ${ev.isNull} = $resultIsNull;" } else { + ev.isNull = "false" "" } + val evaluate = if (returnNullable) { + if (ctx.defaultValue(dataType) == "null") { + s""" + ${ev.value} = $callFunc; + ${ev.isNull} = ${ev.value} == null; + """ + } else { + val boxedResult = ctx.freshName("boxedResult") + s""" + ${ctx.boxedType(dataType)} $boxedResult = $callFunc; + ${ev.isNull} = $boxedResult == null; + if (!${ev.isNull}) { + ${ev.value} = $boxedResult; + } + """ + } + } else { + s"${ev.value} = $callFunc;" + } + val code = s""" $argCode - boolean ${ev.isNull} = $resultIsNull; - final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc; - $postNullCheck + $prepareIsNull + $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!$resultIsNull) { + $evaluate + } """ ev.copy(code = code) } From ab9d6fcdcc3ba82c0fe4e86add47355e7fe759e6 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 29 Nov 2016 14:36:35 +0900 Subject: [PATCH 2/6] Add `returnNullable` parameter to callers of `StaticInvoke`. 
--- .../sql/catalyst/JavaTypeInference.scala | 24 ++++++++----- .../spark/sql/catalyst/ScalaReflection.scala | 36 ++++++++++++------- .../sql/catalyst/encoders/RowEncoder.scala | 27 +++++++++----- 3 files changed, 58 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 04f0cfce883f..59c9161a9f2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -204,7 +204,8 @@ object JavaTypeInference { ObjectType(c), "toJavaDate", getPath :: Nil, - propagateNull = true) + propagateNull = true, + returnNullable = false) case c if c == classOf[java.sql.Timestamp] => StaticInvoke( @@ -212,7 +213,8 @@ object JavaTypeInference { ObjectType(c), "toJavaTimestamp", getPath :: Nil, - propagateNull = true) + propagateNull = true, + returnNullable = false) case c if c == classOf[java.lang.String] => Invoke(getPath, "toString", ObjectType(classOf[String])) @@ -256,7 +258,8 @@ object JavaTypeInference { "array", ObjectType(classOf[Array[Any]])) - StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil) + StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil, + returnNullable = false) case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) @@ -285,7 +288,8 @@ object JavaTypeInference { ArrayBasedMapData.getClass, ObjectType(classOf[JMap[_, _]]), "toJavaMap", - keyData :: valueData :: Nil) + keyData :: valueData :: Nil, + returnNullable = false) case other => val properties = getJavaBeanProperties(other) @@ -350,28 +354,32 @@ object JavaTypeInference { classOf[UTF8String], StringType, "fromString", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = true) case c if c == classOf[java.sql.Timestamp] 
=> StaticInvoke( DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, DateType, "fromJavaDate", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.math.BigDecimal] => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7bcaea7ea2f7..aff3c3dbb5f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -239,14 +239,16 @@ object ScalaReflection extends ScalaReflection { DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", - getPath :: Nil) + getPath :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", - getPath :: Nil) + getPath :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.lang.String] => Invoke(getPath, "toString", ObjectType(classOf[String])) @@ -316,7 +318,8 @@ object ScalaReflection extends ScalaReflection { scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", - array :: Nil) + array :: Nil, + returnNullable = false) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map @@ -344,7 +347,8 @@ object ScalaReflection extends ScalaReflection { ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", - keyData :: valueData :: Nil) + keyData :: 
valueData :: Nil, + returnNullable = false) case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() @@ -449,7 +453,8 @@ object ScalaReflection extends ScalaReflection { classOf[UnsafeArrayData], ArrayType(dt, false), "fromPrimitiveArray", - input :: Nil) + input :: Nil, + returnNullable = false) } else { NewInstance( classOf[GenericArrayData], @@ -505,49 +510,56 @@ object ScalaReflection extends ScalaReflection { classOf[UTF8String], StringType, "fromString", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = true) case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, DateType, "fromJavaDate", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.math.BigDecimal] => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.math.BigInteger] => StaticInvoke( Decimal.getClass, DecimalType.BigIntDecimal, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[scala.math.BigInt] => StaticInvoke( Decimal.getClass, DecimalType.BigIntDecimal, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.lang.Integer] => Invoke(inputObject, "intValue", IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index e95e97b9dc6c..a157ada2a26b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -96,28 +96,32 @@ object RowEncoder { DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case DateType => StaticInvoke( DateTimeUtils.getClass, DateType, "fromJavaDate", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case d: DecimalType => StaticInvoke( Decimal.getClass, d, "fromDecimal", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case StringType => StaticInvoke( classOf[UTF8String], StringType, "fromString", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = true) case t @ ArrayType(et, cn) => et match { @@ -126,7 +130,8 @@ object RowEncoder { classOf[ArrayData], t, "toArrayData", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case _ => MapObjects( element => serializerFor(ValidateExternalType(element, et), et), inputObject, @@ -252,14 +257,16 @@ object RowEncoder { DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", - input :: Nil) + input :: Nil, + returnNullable = false) case DateType => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", - input :: Nil) + input :: Nil, + returnNullable = false) case _: DecimalType => Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) @@ -277,7 +284,8 @@ object RowEncoder { scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", - arrayData :: Nil) + arrayData :: Nil, + returnNullable = false) case MapType(kt, vt, valueNullable) => val keyArrayType = ArrayType(kt, false) @@ -290,7 +298,8 @@ object RowEncoder { ArrayBasedMapData.getClass, 
ObjectType(classOf[Map[_, _]]), "toScalaMap", - keyData :: valueData :: Nil) + keyData :: valueData :: Nil, + returnNullable = false) case schema @ StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => From f7b1aa05a64bc8efc43f6e932c1fcb06f18866f7 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 1 Dec 2016 14:08:58 +0900 Subject: [PATCH 3/6] Update a comment to follow #15780. --- .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 6bf8542c0064..9c92b9e534a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -116,7 +116,7 @@ trait InvokeLike extends Expression with NonSQLExpression { * @param arguments An optional list of expressions to pass as arguments to the function. * @param propagateNull When true, and any of the arguments is null, null will be returned instead * of calling the function. - * @param returnNullable When false, return value must be non-null. + * @param returnNullable When false, indicating the invoked method will return non-null value. */ case class StaticInvoke( staticObject: Class[_], From c83919e6ca75b3ea803a9227dc327cc34ec8e728 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 2 Dec 2016 13:44:20 +0900 Subject: [PATCH 4/6] Update a comment to follow changes. 
--- .../spark/sql/catalyst/expressions/objects/objects.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index c4308fcea806..e60be19a420b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -116,7 +116,8 @@ trait InvokeLike extends Expression with NonSQLExpression { * @param arguments An optional list of expressions to pass as arguments to the function. * @param propagateNull When true, and any of the arguments is null, null will be returned instead * of calling the function. - * @param returnNullable When false, indicating the invoked method will return non-null value. + * @param returnNullable When false, indicates that the invoked method will always return a + * non-null value. */ case class StaticInvoke( staticObject: Class[_], From 63f92b97c8479e7a34a747c38025875d5a485e9b Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 3 Jul 2017 09:41:57 +0900 Subject: [PATCH 5/6] Remove explicit `propagateNull = true`. 
--- .../sql/catalyst/JavaTypeInference.scala | 3 --- .../spark/sql/catalyst/ScalaReflection.scala | 21 +++++++------------ 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 830d436370e0..b93e53841d24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -216,7 +216,6 @@ object JavaTypeInference { ObjectType(c), "valueOf", getPath :: Nil, - propagateNull = true, returnNullable = false) case c if c == classOf[java.sql.Date] => @@ -225,7 +224,6 @@ object JavaTypeInference { ObjectType(c), "toJavaDate", getPath :: Nil, - propagateNull = true, returnNullable = false) case c if c == classOf[java.sql.Timestamp] => @@ -234,7 +232,6 @@ object JavaTypeInference { ObjectType(c), "toJavaTimestamp", getPath :: Nil, - propagateNull = true, returnNullable = false) case c if c == classOf[java.lang.String] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7749662adac4..5afff183cf3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -206,44 +206,37 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true, - returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] 
val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true, - returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true, - returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true, - returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true, - returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true, - returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true, - returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( From 
b849b59f03c824be0530565032154f12e5001c66 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 3 Jul 2017 09:42:15 +0900 Subject: [PATCH 6/6] Modify `returnNullable` for `UTF8String.fromString`. --- .../scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 2 +- .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index b93e53841d24..0593ef7f5c32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -369,7 +369,7 @@ object JavaTypeInference { StringType, "fromString", inputObject :: Nil, - returnNullable = true) + returnNullable = false) case c if c == classOf[java.sql.Timestamp] => StaticInvoke( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 5afff183cf3d..bf1fce1329a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -507,7 +507,7 @@ object ScalaReflection extends ScalaReflection { StringType, "fromString", inputObject :: Nil, - returnNullable = true) + returnNullable = false) case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index c7b9f88af719..cc32fac67e92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -121,7 +121,7 @@ object RowEncoder { StringType, "fromString", inputObject :: Nil, - returnNullable = true) + returnNullable = false) case t @ ArrayType(et, cn) => et match {