Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
}

// three function arguments are: child.primitive, result.primitive and result.isNull
// it returns the code snippets to be put in null safe evaluation region
// The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull`
// in parameter list, because the returned code will be put in null safe evaluation region.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some renaming to make it more readable.

private[this] type CastFunction = (String, String, String) => String

private[this] def nullSafeCastFunction(
Expand Down Expand Up @@ -584,15 +584,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
throw new SparkException(s"Cannot cast $from to $to.")
}

// Since we need to cast child expressions recursively inside ComplexTypes, such as Map's
// Since we need to cast input expressions recursively inside ComplexTypes, such as Map's
// Key and Value, Struct's field, we need to name out all the variable names involved in a cast.
private[this] def castCode(ctx: CodegenContext, childPrim: String, childNull: String,
resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = {
private[this] def castCode(ctx: CodegenContext, input: String, inputIsNull: String,
result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indents.

s"""
boolean $resultNull = $childNull;
${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)};
if (!$childNull) {
${cast(childPrim, resultPrim, resultNull)}
boolean $resultIsNull = $inputIsNull;
${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)};
if (!$inputIsNull) {
${cast(input, result, resultIsNull)}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some renaming to make it more readable.

}
"""
}
Expand Down Expand Up @@ -1014,8 +1014,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx)
}
val rowClass = classOf[GenericInternalRow].getName
val result = ctx.freshName("result")
val tmpRow = ctx.freshName("tmpRow")
val tmpResult = ctx.freshName("tmpResult")
val tmpInput = ctx.freshName("tmpInput")

val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) =>
val fromFieldPrim = ctx.freshName("ffp")
Expand All @@ -1024,37 +1024,33 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val toFieldNull = ctx.freshName("tfn")
val fromType = ctx.javaType(from.fields(i).dataType)
s"""
boolean $fromFieldNull = $tmpRow.isNullAt($i);
boolean $fromFieldNull = $tmpInput.isNullAt($i);
if ($fromFieldNull) {
$result.setNullAt($i);
$tmpResult.setNullAt($i);
} else {
$fromType $fromFieldPrim =
${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)};
${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)};
${castCode(ctx, fromFieldPrim,
fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
if ($toFieldNull) {
$result.setNullAt($i);
$tmpResult.setNullAt($i);
} else {
${ctx.setColumn(result, to.fields(i).dataType, i, toFieldPrim)};
${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)};
}
}
"""
}
val fieldsEvalCodes = if (ctx.currentVars == null) {
ctx.splitExpressions(
expressions = fieldsEvalCode,
funcName = "castStruct",
arguments = ("InternalRow", tmpRow) :: (rowClass, result) :: Nil)
} else {
fieldsEvalCode.mkString("\n")
}
val fieldsEvalCodes = ctx.splitExpressions(
expressions = fieldsEvalCode,
funcName = "castStruct",
arguments = ("InternalRow", tmpInput) :: (rowClass, tmpResult) :: Nil)

(c, evPrim, evNull) =>
(input, result, resultIsNull) =>
s"""
final $rowClass $result = new $rowClass(${fieldsCasts.length});
final InternalRow $tmpRow = $c;
final $rowClass $tmpResult = new $rowClass(${fieldsCasts.length});
final InternalRow $tmpInput = $input;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tmpInput and tmpResult are the only inputs we need for the generated code to cast struct, and we don't depend on ctx.INPUT_ROW and ctx.currentVars here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in another word, the code to cast a struct is always row-based, the input is a variable of type InternalRow. We don't care about ctx.INPUT_ROW and ctx.currentVars here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, now I see! Thanks for the kind explanation.

$fieldsEvalCodes
$evPrim = $result;
$result = $tmpResult;
"""
}

Expand Down