Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -216,23 +216,23 @@ object JavaTypeInference {
ObjectType(c),
"valueOf",
getPath :: Nil,
propagateNull = true)
returnNullable = false)

case c if c == classOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(c),
"toJavaDate",
getPath :: Nil,
propagateNull = true)
returnNullable = false)

case c if c == classOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(c),
"toJavaTimestamp",
getPath :: Nil,
propagateNull = true)
returnNullable = false)

case c if c == classOf[java.lang.String] =>
Invoke(getPath, "toString", ObjectType(classOf[String]))
Expand Down Expand Up @@ -300,7 +300,8 @@ object JavaTypeInference {
ArrayBasedMapData.getClass,
ObjectType(classOf[JMap[_, _]]),
"toJavaMap",
keyData :: valueData :: Nil)
keyData :: valueData :: Nil,
returnNullable = false)

case other =>
val properties = getJavaBeanReadableAndWritableProperties(other)
Expand Down Expand Up @@ -367,28 +368,32 @@ object JavaTypeInference {
classOf[UTF8String],
StringType,
"fromString",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case c if c == classOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
TimestampType,
"fromJavaTimestamp",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case c if c == classOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils.getClass,
DateType,
"fromJavaDate",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case c if c == classOf[java.math.BigDecimal] =>
StaticInvoke(
Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case c if c == classOf[java.lang.Boolean] =>
Invoke(inputObject, "booleanValue", BooleanType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,51 +206,53 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.lang.Integer] =>
val boxedType = classOf[java.lang.Integer]
val objectType = ObjectType(boxedType)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)

case t if t <:< localTypeOf[java.lang.Long] =>
val boxedType = classOf[java.lang.Long]
val objectType = ObjectType(boxedType)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)

case t if t <:< localTypeOf[java.lang.Double] =>
val boxedType = classOf[java.lang.Double]
val objectType = ObjectType(boxedType)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)

case t if t <:< localTypeOf[java.lang.Float] =>
val boxedType = classOf[java.lang.Float]
val objectType = ObjectType(boxedType)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)

case t if t <:< localTypeOf[java.lang.Short] =>
val boxedType = classOf[java.lang.Short]
val objectType = ObjectType(boxedType)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)

case t if t <:< localTypeOf[java.lang.Byte] =>
val boxedType = classOf[java.lang.Byte]
val objectType = ObjectType(boxedType)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)

case t if t <:< localTypeOf[java.lang.Boolean] =>
val boxedType = classOf[java.lang.Boolean]
val objectType = ObjectType(boxedType)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)

case t if t <:< localTypeOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Date]),
"toJavaDate",
getPath :: Nil)
getPath :: Nil,
returnNullable = false)

case t if t <:< localTypeOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Timestamp]),
"toJavaTimestamp",
getPath :: Nil)
getPath :: Nil,
returnNullable = false)

case t if t <:< localTypeOf[java.lang.String] =>
Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false)
Expand Down Expand Up @@ -446,7 +448,8 @@ object ScalaReflection extends ScalaReflection {
classOf[UnsafeArrayData],
ArrayType(dt, false),
"fromPrimitiveArray",
input :: Nil)
input :: Nil,
returnNullable = false)
} else {
NewInstance(
classOf[GenericArrayData],
Expand Down Expand Up @@ -503,49 +506,56 @@ object ScalaReflection extends ScalaReflection {
classOf[UTF8String],
StringType,
"fromString",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case t if t <:< localTypeOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
TimestampType,
"fromJavaTimestamp",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case t if t <:< localTypeOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils.getClass,
DateType,
"fromJavaDate",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case t if t <:< localTypeOf[BigDecimal] =>
StaticInvoke(
Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case t if t <:< localTypeOf[java.math.BigDecimal] =>
StaticInvoke(
Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case t if t <:< localTypeOf[java.math.BigInteger] =>
StaticInvoke(
Decimal.getClass,
DecimalType.BigIntDecimal,
"apply",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case t if t <:< localTypeOf[scala.math.BigInt] =>
StaticInvoke(
Decimal.getClass,
DecimalType.BigIntDecimal,
"apply",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case t if t <:< localTypeOf[java.lang.Integer] =>
Invoke(inputObject, "intValue", IntegerType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,28 +96,32 @@ object RowEncoder {
DateTimeUtils.getClass,
TimestampType,
"fromJavaTimestamp",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case DateType =>
StaticInvoke(
DateTimeUtils.getClass,
DateType,
"fromJavaDate",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case d: DecimalType =>
StaticInvoke(
Decimal.getClass,
d,
"fromDecimal",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case StringType =>
StaticInvoke(
classOf[UTF8String],
StringType,
"fromString",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)

case t @ ArrayType(et, cn) =>
et match {
Expand All @@ -126,7 +130,8 @@ object RowEncoder {
classOf[ArrayData],
t,
"toArrayData",
inputObject :: Nil)
inputObject :: Nil,
returnNullable = false)
case _ => MapObjects(
element => serializerFor(ValidateExternalType(element, et), et),
inputObject,
Expand Down Expand Up @@ -254,14 +259,16 @@ object RowEncoder {
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Timestamp]),
"toJavaTimestamp",
input :: Nil)
input :: Nil,
returnNullable = false)

case DateType =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Date]),
"toJavaDate",
input :: Nil)
input :: Nil,
returnNullable = false)

case _: DecimalType =>
Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
Expand All @@ -280,7 +287,8 @@ object RowEncoder {
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
arrayData :: Nil)
arrayData :: Nil,
returnNullable = false)

case MapType(kt, vt, valueNullable) =>
val keyArrayType = ArrayType(kt, false)
Expand All @@ -293,7 +301,8 @@ object RowEncoder {
ArrayBasedMapData.getClass,
ObjectType(classOf[Map[_, _]]),
"toScalaMap",
keyData :: valueData :: Nil)
keyData :: valueData :: Nil,
returnNullable = false)

case schema @ StructType(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,20 @@ trait InvokeLike extends Expression with NonSQLExpression {
* @param arguments An optional list of expressions to pass as arguments to the function.
* @param propagateNull When true, and any of the arguments is null, null will be returned instead
* of calling the function.
* @param returnNullable When false, indicating the invoked method will always return
* non-null value.
*/
case class StaticInvoke(
staticObject: Class[_],
dataType: DataType,
functionName: String,
arguments: Seq[Expression] = Nil,
propagateNull: Boolean = true) extends InvokeLike {
propagateNull: Boolean = true,
returnNullable: Boolean = true) extends InvokeLike {

val objectName = staticObject.getName.stripSuffix("$")

override def nullable: Boolean = true
override def nullable: Boolean = needNullCheck || returnNullable
override def children: Seq[Expression] = arguments

override def eval(input: InternalRow): Any =
Expand All @@ -141,19 +144,40 @@ case class StaticInvoke(

val callFunc = s"$objectName.$functionName($argString)"

// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
val prepareIsNull = if (nullable) {
s"boolean ${ev.isNull} = $resultIsNull;"
} else {
ev.isNull = "false"
""
}

val evaluate = if (returnNullable) {
if (ctx.defaultValue(dataType) == "null") {
s"""
${ev.value} = $callFunc;
${ev.isNull} = ${ev.value} == null;
"""
} else {
val boxedResult = ctx.freshName("boxedResult")
s"""
${ctx.boxedType(dataType)} $boxedResult = $callFunc;
${ev.isNull} = $boxedResult == null;
if (!${ev.isNull}) {
${ev.value} = $boxedResult;
}
"""
}
} else {
s"${ev.value} = $callFunc;"
}

val code = s"""
$argCode
boolean ${ev.isNull} = $resultIsNull;
final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc;
$postNullCheck
$prepareIsNull
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
if (!$resultIsNull) {
$evaluate
}
"""
ev.copy(code = code)
}
Expand Down