diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f4ecbdb8393a..b8d3661a00ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -548,8 +548,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) } - // three function arguments are: child.primitive, result.primitive and result.isNull - // it returns the code snippets to be put in null safe evaluation region + // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` + // in parameter list, because the returned code will be put in null safe evaluation region. private[this] type CastFunction = (String, String, String) => String private[this] def nullSafeCastFunction( @@ -584,15 +584,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String throw new SparkException(s"Cannot cast $from to $to.") } - // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's + // Since we need to cast input expressions recursively inside ComplexTypes, such as Map's // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. - private[this] def castCode(ctx: CodegenContext, childPrim: String, childNull: String, - resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = { + private[this] def castCode(ctx: CodegenContext, input: String, inputIsNull: String, + result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = { s""" - boolean $resultNull = $childNull; - ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)}; - if (!$childNull) { - ${cast(childPrim, resultPrim, resultNull)} + boolean $resultIsNull = $inputIsNull; + ${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)}; + if (!$inputIsNull) { + ${cast(input, result, resultIsNull)} } """ } @@ -1014,8 +1014,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } val rowClass = classOf[GenericInternalRow].getName - val result = ctx.freshName("result") - val tmpRow = ctx.freshName("tmpRow") + val tmpResult = ctx.freshName("tmpResult") + val tmpInput = ctx.freshName("tmpInput") val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => val fromFieldPrim = ctx.freshName("ffp") @@ -1024,37 +1024,33 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val toFieldNull = ctx.freshName("tfn") val fromType = ctx.javaType(from.fields(i).dataType) s""" - boolean $fromFieldNull = $tmpRow.isNullAt($i); + boolean $fromFieldNull = $tmpInput.isNullAt($i); if ($fromFieldNull) { - $result.setNullAt($i); + $tmpResult.setNullAt($i); } else { $fromType $fromFieldPrim = - ${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)}; + ${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)}; ${castCode(ctx, fromFieldPrim, fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} if ($toFieldNull) { - $result.setNullAt($i); + $tmpResult.setNullAt($i); } else { - ${ctx.setColumn(result, to.fields(i).dataType, i, toFieldPrim)}; + ${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; } } """ } - val fieldsEvalCodes = if (ctx.currentVars == null) { - ctx.splitExpressions( - expressions = fieldsEvalCode, - funcName = "castStruct", - arguments = ("InternalRow", tmpRow) :: (rowClass, result) :: Nil) - } else { - fieldsEvalCode.mkString("\n") - } + val fieldsEvalCodes = ctx.splitExpressions( + expressions = fieldsEvalCode, + funcName = "castStruct", + arguments = ("InternalRow", tmpInput) :: (rowClass, tmpResult) :: Nil) - (c, evPrim, evNull) => + (input, result, resultIsNull) => s""" - final $rowClass $result = new $rowClass(${fieldsCasts.length}); - final InternalRow $tmpRow = $c; + final $rowClass $tmpResult = new $rowClass(${fieldsCasts.length}); + final InternalRow $tmpInput = $input; $fieldsEvalCodes - $evPrim = $result; + $result = $tmpResult; """ }